diff --git a/.build-tools/builtin-authentication-profiles.yaml b/.build-tools/builtin-authentication-profiles.yaml index 9113cf286b..0bd9548bda 100644 --- a/.build-tools/builtin-authentication-profiles.yaml +++ b/.build-tools/builtin-authentication-profiles.yaml @@ -3,33 +3,84 @@ aws: description: | Authenticate using an Access Key ID and Secret Access Key included in the metadata metadata: + - name: region + type: string + required: false + description: | + The AWS Region where the AWS resource is deployed to. + This will be marked required in Dapr 1.17. + example: '"us-east-1"' - name: awsRegion type: string - required: true + required: false description: | + This maintains backwards compatibility with existing fields. + It will be deprecated as of Dapr 1.17. Use 'region' instead. The AWS Region where the AWS resource is deployed to. example: '"us-east-1"' - name: accessKey description: AWS access key associated with an IAM account - required: true + required: false sensitive: true example: '"AKIAIOSFODNN7EXAMPLE"' - name: secretKey description: The secret key associated with the access key - required: true + required: false sensitive: true example: '"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"' - name: sessionToken + type: string required: false sensitive: true description: | AWS session token to use. A session token is only required if you are using temporary security credentials. example: '"TOKEN"' + - title: "AWS: Assume IAM Role" + description: | + Assume a specific IAM role. Note: This is only supported for Kafka and PostgreSQL. + metadata: + - name: region type: string + required: true + description: | + The AWS Region where the AWS resource is deployed to. + example: '"us-east-1"' + - name: assumeRoleArn + type: string + required: false + description: | + IAM role that has access to AWS resource. + This is another option to authenticate with MSK and RDS Aurora aside from the AWS Credentials. + This will be marked required in Dapr 1.17. + example: '"arn:aws:iam::123456789:role/mskRole"' + - name: sessionName + type: string + description: | + The session name for assuming a role. + example: '"MyAppSession"' + default: '"DaprDefaultSession"' - title: "AWS: Credentials from Environment Variables" description: Use AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY from the environment - + - title: "AWS: IAM Roles Anywhere" + description: Use X.509 certificates to establish trust between your AWS account and the Dapr cluster using AWS IAM Roles Anywhere. + metadata: + - name: trustAnchorArn + description: | + ARN of the AWS Trust Anchor in the AWS account granting trust to the Dapr Certificate Authority. + example: arn:aws:rolesanywhere:us-west-1:012345678910:trust-anchor/01234568-0123-0123-0123-012345678901 + required: true + - name: trustProfileArn + description: | + ARN of the AWS IAM Profile in the trusting AWS account. + example: arn:aws:rolesanywhere:us-west-1:012345678910:profile/01234568-0123-0123-0123-012345678901 + required: true + - name: assumeRoleArn + description: | + ARN of the AWS IAM role to assume in the trusting AWS account. + example: arn:aws:iam:012345678910:role/exampleIAMRoleName + required: true + azuread: - title: "Azure AD: Managed identity" description: Authenticate using Azure AD and a managed identity. diff --git a/.build-tools/pkg/metadataschema/builtin-authentication-profiles.go b/.build-tools/pkg/metadataschema/builtin-authentication-profiles.go index 28d6659188..f0971e8df9 100644 --- a/.build-tools/pkg/metadataschema/builtin-authentication-profiles.go +++ b/.build-tools/pkg/metadataschema/builtin-authentication-profiles.go @@ -32,14 +32,38 @@ func ParseBuiltinAuthenticationProfile(bi BuiltinAuthenticationProfile, componen for i, profile := range profiles { res[i] = profile - res[i].Metadata = mergedMetadata(bi.Metadata, res[i].Metadata...) + // deep copy the metadata slice to avoid side effects when manually updating some req -> non-req fields to deprecate some fields for kafka/postgres + // TODO: rm all of this manipulation in Dapr 1.17!! + originalMetadata := profile.Metadata + metadataCopy := make([]Metadata, len(originalMetadata)) + copy(metadataCopy, originalMetadata) - // If component is PostgreSQL, filter out duplicated aws profile fields - if strings.ToLower(componentTitle) == "postgresql" && bi.Name == "aws" { - res[i].Metadata = filterOutDuplicateFields(res[i].Metadata) + if componentTitle == "Apache Kafka" || strings.ToLower(componentTitle) == "postgresql" { + removeRequiredOnSomeAWSFields(&metadataCopy) } + merged := mergedMetadata(bi.Metadata, metadataCopy...) + + // Note: We must apply the removal of deprecated fields after the merge!! + + // Here, we remove some deprecated fields as we support the transition to a new auth profile + if profile.Title == "AWS: Assume IAM Role" && componentTitle == "Apache Kafka" || profile.Title == "AWS: Assume IAM Role" && strings.ToLower(componentTitle) == "postgresql" { + merged = removeSomeDeprecatedFieldsOnUnrelatedAuthProfiles(merged) + } + + // Here, there are no metadata fields that need deprecating + if profile.Title == "AWS: Credentials from Environment Variables" && componentTitle == "Apache Kafka" || profile.Title == "AWS: Credentials from Environment Variables" && strings.ToLower(componentTitle) == "postgresql" { + merged = removeAllDeprecatedFieldsOnUnrelatedAuthProfiles(merged) + } + + // Here, this is a new auth profile, so rm all deprecating fields as unrelated. + if profile.Title == "AWS: IAM Roles Anywhere" && componentTitle == "Apache Kafka" || profile.Title == "AWS: IAM Roles Anywhere" && strings.ToLower(componentTitle) == "postgresql" { + merged = removeAllDeprecatedFieldsOnUnrelatedAuthProfiles(merged) + } + + res[i].Metadata = merged } + return res, nil } @@ -54,26 +78,55 @@ func mergedMetadata(base []Metadata, add ...Metadata) []Metadata { return res } -// filterOutDuplicateFields removes specific duplicated fields from the metadata -func filterOutDuplicateFields(metadata []Metadata) []Metadata { - duplicateFields := map[string]int{ - "awsRegion": 0, - "accessKey": 0, - "secretKey": 0, +// removeRequiredOnSomeAWSFields needs to be removed in Dapr 1.17 as duplicated AWS IAM fields get removed, +// and we standardize on these fields. +// Currently, there are: awsAccessKey, accessKey and awsSecretKey, secretKey, and awsRegion and region fields. +// We normally have accessKey, secretKey, and region fields marked required as it is part of the builtin AWS auth profile fields. +// However, as we rm the aws prefixed ones, we need to then mark the normally required ones as not required only for postgres and kafka. +// This way we do not break existing users, and transition them to the standardized fields. +func removeRequiredOnSomeAWSFields(metadata *[]Metadata) { + if metadata == nil { + return } - filteredMetadata := []Metadata{} + for i := range *metadata { + field := &(*metadata)[i] + + if field == nil { + continue + } + if field.Name == "accessKey" || field.Name == "secretKey" || field.Name == "region" { + field.Required = false + } + } +} + +func removeAllDeprecatedFieldsOnUnrelatedAuthProfiles(metadata []Metadata) []Metadata { + filteredMetadata := []Metadata{} for _, field := range metadata { - if _, exists := duplicateFields[field.Name]; !exists { + if strings.HasPrefix(field.Name, "aws") { + continue + } else { filteredMetadata = append(filteredMetadata, field) + } + } + + return filteredMetadata +} + +func removeSomeDeprecatedFieldsOnUnrelatedAuthProfiles(metadata []Metadata) []Metadata { + filteredMetadata := []Metadata{} + + for _, field := range metadata { + // region is required in Assume Role auth profile, so this is needed for now. + if field.Name == "region" { + field.Required = true + } + if field.Name == "awsAccessKey" || field.Name == "awsSecretKey" || field.Name == "awsSessionToken" || field.Name == "awsRegion" { + continue } else { - if field.Name == "awsRegion" && duplicateFields["awsRegion"] == 0 { - filteredMetadata = append(filteredMetadata, field) - duplicateFields["awsRegion"]++ - } else if field.Name != "awsRegion" { - continue - } + filteredMetadata = append(filteredMetadata, field) } } diff --git a/.github/scripts/dapr_bot.js b/.github/scripts/dapr_bot.js index 279799958a..328972d27e 100644 --- a/.github/scripts/dapr_bot.js +++ b/.github/scripts/dapr_bot.js @@ -21,6 +21,7 @@ const owners = [ 'RyanLettieri', 'shivamkm07', 'shubham1172', + 'sicoyle', 'skyao', 'Taction', 'tmacam', diff --git a/.github/workflows/certification.yml b/.github/workflows/certification.yml.disabled similarity index 100% rename from .github/workflows/certification.yml rename to .github/workflows/certification.yml.disabled diff --git a/bindings/aws/dynamodb/dynamodb.go b/bindings/aws/dynamodb/dynamodb.go index bd882e7b55..2096f22433 100644 --- a/bindings/aws/dynamodb/dynamodb.go +++ b/bindings/aws/dynamodb/dynamodb.go @@ -31,9 +31,9 @@ import ( // DynamoDB allows performing stateful operations on AWS DynamoDB. type DynamoDB struct { - client *dynamodb.DynamoDB - table string - logger logger.Logger + authProvider awsAuth.Provider + table string + logger logger.Logger } type dynamoDBMetadata struct { @@ -51,18 +51,27 @@ func NewDynamoDB(logger logger.Logger) bindings.OutputBinding { } // Init performs connection parsing for DynamoDB. -func (d *DynamoDB) Init(_ context.Context, metadata bindings.Metadata) error { +func (d *DynamoDB) Init(ctx context.Context, metadata bindings.Metadata) error { meta, err := d.getDynamoDBMetadata(metadata) if err != nil { return err } - client, err := d.getClient(meta) + opts := awsAuth.Options{ + Logger: d.logger, + Properties: metadata.Properties, + Region: meta.Region, + Endpoint: meta.Endpoint, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: meta.SessionToken, + } + + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } - - d.client = client + d.authProvider = provider d.table = meta.Table return nil @@ -84,7 +93,7 @@ func (d *DynamoDB) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bi return nil, err } - _, err = d.client.PutItemWithContext(ctx, &dynamodb.PutItemInput{ + _, err = d.authProvider.DynamoDB().DynamoDB.PutItemWithContext(ctx, &dynamodb.PutItemInput{ Item: item, TableName: aws.String(d.table), }) @@ -105,16 +114,6 @@ func (d *DynamoDB) getDynamoDBMetadata(spec bindings.Metadata) (*dynamoDBMetadat return &meta, nil } -func (d *DynamoDB) getClient(metadata *dynamoDBMetadata) (*dynamodb.DynamoDB, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - c := dynamodb.New(sess) - - return c, nil -} - // GetComponentMetadata returns the metadata of the component. func (d *DynamoDB) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { metadataStruct := dynamoDBMetadata{} @@ -123,5 +122,8 @@ func (d *DynamoDB) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { } func (d *DynamoDB) Close() error { + if d.authProvider != nil { + return d.authProvider.Close() + } return nil } diff --git a/bindings/aws/kinesis/kinesis.go b/bindings/aws/kinesis/kinesis.go index dbe0ceb918..bf684f8bbb 100644 --- a/bindings/aws/kinesis/kinesis.go +++ b/bindings/aws/kinesis/kinesis.go @@ -27,7 +27,6 @@ import ( "github.com/aws/aws-sdk-go/service/kinesis" "github.com/cenkalti/backoff/v4" "github.com/google/uuid" - "github.com/vmware/vmware-go-kcl/clientlibrary/config" "github.com/vmware/vmware-go-kcl/clientlibrary/interfaces" "github.com/vmware/vmware-go-kcl/clientlibrary/worker" @@ -40,15 +39,16 @@ import ( // AWSKinesis allows receiving and sending data to/from AWS Kinesis stream. type AWSKinesis struct { - client *kinesis.Kinesis - metadata *kinesisMetadata + authProvider awsAuth.Provider + metadata *kinesisMetadata - worker *worker.Worker - workerConfig *config.KinesisClientLibConfiguration + worker *worker.Worker - streamARN *string - consumerARN *string - logger logger.Logger + streamName string + consumerName string + consumerARN *string + logger logger.Logger + consumerMode string closed atomic.Bool closeCh chan struct{} @@ -112,30 +112,25 @@ func (a *AWSKinesis) Init(ctx context.Context, metadata bindings.Metadata) error return fmt.Errorf("%s invalid \"mode\" field %s", "aws.kinesis", m.KinesisConsumerMode) } - client, err := a.getClient(m) - if err != nil { - return err - } + a.consumerMode = m.KinesisConsumerMode + a.streamName = m.StreamName + a.consumerName = m.ConsumerName + a.metadata = m - streamName := aws.String(m.StreamName) - stream, err := client.DescribeStreamWithContext(ctx, &kinesis.DescribeStreamInput{ - StreamName: streamName, - }) + opts := awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: "", + } + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } - - if m.KinesisConsumerMode == SharedThroughput { - kclConfig := config.NewKinesisClientLibConfigWithCredential(m.ConsumerName, - m.StreamName, m.Region, m.ConsumerName, - client.Config.Credentials) - a.workerConfig = kclConfig - } - - a.streamARN = stream.StreamDescription.StreamARN - a.metadata = m - a.client = client - + a.authProvider = provider return nil } @@ -148,7 +143,7 @@ func (a *AWSKinesis) Invoke(ctx context.Context, req *bindings.InvokeRequest) (* if partitionKey == "" { partitionKey = uuid.New().String() } - _, err := a.client.PutRecordWithContext(ctx, &kinesis.PutRecordInput{ + _, err := a.authProvider.Kinesis().Kinesis.PutRecordWithContext(ctx, &kinesis.PutRecordInput{ StreamName: &a.metadata.StreamName, Data: req.Data, PartitionKey: &partitionKey, @@ -161,16 +156,15 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er if a.closed.Load() { return errors.New("binding is closed") } - if a.metadata.KinesisConsumerMode == SharedThroughput { - a.worker = worker.NewWorker(a.recordProcessorFactory(ctx, handler), a.workerConfig) + a.worker = worker.NewWorker(a.recordProcessorFactory(ctx, handler), a.authProvider.Kinesis().WorkerCfg(ctx, a.streamName, a.consumerName, a.consumerMode)) err = a.worker.Start() if err != nil { return err } } else if a.metadata.KinesisConsumerMode == ExtendedFanout { var stream *kinesis.DescribeStreamOutput - stream, err = a.client.DescribeStream(&kinesis.DescribeStreamInput{StreamName: &a.metadata.StreamName}) + stream, err = a.authProvider.Kinesis().Kinesis.DescribeStream(&kinesis.DescribeStreamInput{StreamName: &a.metadata.StreamName}) if err != nil { return err } @@ -180,6 +174,10 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er } } + stream, err := a.authProvider.Kinesis().Stream(ctx, a.streamName) + if err != nil { + return fmt.Errorf("failed to get kinesis stream arn: %v", err) + } // Wait for context cancelation then stop a.wg.Add(1) go func() { @@ -191,7 +189,7 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er if a.metadata.KinesisConsumerMode == SharedThroughput { a.worker.Shutdown() } else if a.metadata.KinesisConsumerMode == ExtendedFanout { - a.deregisterConsumer(a.streamARN, a.consumerARN) + a.deregisterConsumer(ctx, stream, a.consumerARN) } }() @@ -226,8 +224,7 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes return default: } - - sub, err := a.client.SubscribeToShardWithContext(ctx, &kinesis.SubscribeToShardInput{ + sub, err := a.authProvider.Kinesis().Kinesis.SubscribeToShardWithContext(ctx, &kinesis.SubscribeToShardInput{ ConsumerARN: consumerARN, ShardId: s.ShardId, StartingPosition: &kinesis.StartingPosition{Type: aws.String(kinesis.ShardIteratorTypeLatest)}, @@ -269,6 +266,9 @@ func (a *AWSKinesis) Close() error { close(a.closeCh) } a.wg.Wait() + if a.authProvider != nil { + return a.authProvider.Close() + } return nil } @@ -276,7 +276,7 @@ func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*st // Only set timeout on consumer call. conCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - consumer, err := a.client.DescribeStreamConsumerWithContext(conCtx, &kinesis.DescribeStreamConsumerInput{ + consumer, err := a.authProvider.Kinesis().Kinesis.DescribeStreamConsumerWithContext(conCtx, &kinesis.DescribeStreamConsumerInput{ ConsumerName: &a.metadata.ConsumerName, StreamARN: streamARN, }) @@ -288,7 +288,7 @@ func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*st } func (a *AWSKinesis) registerConsumer(ctx context.Context, streamARN *string) (*string, error) { - consumer, err := a.client.RegisterStreamConsumerWithContext(ctx, &kinesis.RegisterStreamConsumerInput{ + consumer, err := a.authProvider.Kinesis().Kinesis.RegisterStreamConsumerWithContext(ctx, &kinesis.RegisterStreamConsumerInput{ ConsumerName: &a.metadata.ConsumerName, StreamARN: streamARN, }) @@ -307,11 +307,11 @@ func (a *AWSKinesis) registerConsumer(ctx context.Context, streamARN *string) (* return consumer.Consumer.ConsumerARN, nil } -func (a *AWSKinesis) deregisterConsumer(streamARN *string, consumerARN *string) error { +func (a *AWSKinesis) deregisterConsumer(ctx context.Context, streamARN *string, consumerARN *string) error { if a.consumerARN != nil { // Use a background context because the running context may have been canceled already ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - _, err := a.client.DeregisterStreamConsumerWithContext(ctx, &kinesis.DeregisterStreamConsumerInput{ + _, err := a.authProvider.Kinesis().Kinesis.DeregisterStreamConsumerWithContext(ctx, &kinesis.DeregisterStreamConsumerInput{ ConsumerARN: consumerARN, StreamARN: streamARN, ConsumerName: &a.metadata.ConsumerName, @@ -342,7 +342,7 @@ func (a *AWSKinesis) waitUntilConsumerExists(ctx aws.Context, input *kinesis.Des tmp := *input inCpy = &tmp } - req, _ := a.client.DescribeStreamConsumerRequest(inCpy) + req, _ := a.authProvider.Kinesis().Kinesis.DescribeStreamConsumerRequest(inCpy) req.SetContext(ctx) req.ApplyOptions(opts...) @@ -354,16 +354,6 @@ func (a *AWSKinesis) waitUntilConsumerExists(ctx aws.Context, input *kinesis.Des return w.WaitWithContext(ctx) } -func (a *AWSKinesis) getClient(metadata *kinesisMetadata) (*kinesis.Kinesis, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - k := kinesis.New(sess) - - return k, nil -} - func (a *AWSKinesis) parseMetadata(meta bindings.Metadata) (*kinesisMetadata, error) { var m kinesisMetadata err := kitmd.DecodeMetadata(meta.Properties, &m) diff --git a/bindings/aws/s3/s3.go b/bindings/aws/s3/s3.go index cc67cec94f..fa20c70a6b 100644 --- a/bindings/aws/s3/s3.go +++ b/bindings/aws/s3/s3.go @@ -29,9 +29,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/google/uuid" @@ -61,11 +59,9 @@ const ( // AWSS3 is a binding for an AWS S3 storage bucket. type AWSS3 struct { - metadata *s3Metadata - s3Client *s3.S3 - uploader *s3manager.Uploader - downloader *s3manager.Downloader - logger logger.Logger + metadata *s3Metadata + authProvider awsAuth.Provider + logger logger.Logger } type s3Metadata struct { @@ -109,23 +105,11 @@ func NewAWSS3(logger logger.Logger) bindings.OutputBinding { return &AWSS3{logger: logger} } -// Init does metadata parsing and connection creation. -func (s *AWSS3) Init(_ context.Context, metadata bindings.Metadata) error { - m, err := s.parseMetadata(metadata) - if err != nil { - return err - } - session, err := s.getSession(m) - if err != nil { - return err - } - - cfg := aws.NewConfig(). - WithS3ForcePathStyle(m.ForcePathStyle). - WithDisableSSL(m.DisableSSL) +func (s *AWSS3) getAWSConfig(opts awsAuth.Options) *aws.Config { + cfg := awsAuth.GetConfig(opts).WithS3ForcePathStyle(s.metadata.ForcePathStyle).WithDisableSSL(s.metadata.DisableSSL) // Use a custom HTTP client to allow self-signed certs - if m.InsecureSSL { + if s.metadata.InsecureSSL { customTransport := http.DefaultTransport.(*http.Transport).Clone() customTransport.TLSClientConfig = &tls.Config{ //nolint:gosec @@ -138,16 +122,40 @@ func (s *AWSS3) Init(_ context.Context, metadata bindings.Metadata) error { s.logger.Infof("aws s3: you are using 'insecureSSL' to skip server config verify which is unsafe!") } + return cfg +} +// Init does metadata parsing and connection creation. +func (s *AWSS3) Init(ctx context.Context, metadata bindings.Metadata) error { + m, err := s.parseMetadata(metadata) + if err != nil { + return err + } s.metadata = m - s.s3Client = s3.New(session, cfg) - s.downloader = s3manager.NewDownloaderWithClient(s.s3Client) - s.uploader = s3manager.NewUploaderWithClient(s.s3Client) + + opts := awsAuth.Options{ + Logger: s.logger, + Properties: metadata.Properties, + Region: m.Region, + Endpoint: m.Endpoint, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, + } + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, s.getAWSConfig(opts)) + if err != nil { + return err + } + s.authProvider = provider return nil } func (s *AWSS3) Close() error { + if s.authProvider != nil { + return s.authProvider.Close() + } return nil } @@ -201,8 +209,7 @@ func (s *AWSS3) create(ctx context.Context, req *bindings.InvokeRequest) (*bindi if metadata.StorageClass != "" { storageClass = aws.String(metadata.StorageClass) } - - resultUpload, err := s.uploader.UploadWithContext(ctx, &s3manager.UploadInput{ + resultUpload, err := s.authProvider.S3().Uploader.UploadWithContext(ctx, &s3manager.UploadInput{ Bucket: ptr.Of(metadata.Bucket), Key: ptr.Of(key), Body: r, @@ -215,7 +222,7 @@ func (s *AWSS3) create(ctx context.Context, req *bindings.InvokeRequest) (*bindi var presignURL string if metadata.PresignTTL != "" { - url, presignErr := s.presignObject(metadata.Bucket, key, metadata.PresignTTL) + url, presignErr := s.presignObject(ctx, metadata.Bucket, key, metadata.PresignTTL) if presignErr != nil { return nil, fmt.Errorf("s3 binding error: %s", presignErr) } @@ -255,7 +262,7 @@ func (s *AWSS3) presign(ctx context.Context, req *bindings.InvokeRequest) (*bind return nil, fmt.Errorf("s3 binding error: required metadata '%s' missing", metadataPresignTTL) } - url, err := s.presignObject(metadata.Bucket, key, metadata.PresignTTL) + url, err := s.presignObject(ctx, metadata.Bucket, key, metadata.PresignTTL) if err != nil { return nil, fmt.Errorf("s3 binding error: %w", err) } @@ -272,13 +279,12 @@ func (s *AWSS3) presign(ctx context.Context, req *bindings.InvokeRequest) (*bind }, nil } -func (s *AWSS3) presignObject(bucket, key, ttl string) (string, error) { +func (s *AWSS3) presignObject(ctx context.Context, bucket, key, ttl string) (string, error) { d, err := time.ParseDuration(ttl) if err != nil { return "", fmt.Errorf("s3 binding error: cannot parse duration %s: %w", ttl, err) } - - objReq, _ := s.s3Client.GetObjectRequest(&s3.GetObjectInput{ + objReq, _ := s.authProvider.S3().S3.GetObjectRequest(&s3.GetObjectInput{ Bucket: ptr.Of(bucket), Key: ptr.Of(key), }) @@ -302,8 +308,7 @@ func (s *AWSS3) get(ctx context.Context, req *bindings.InvokeRequest) (*bindings } buff := &aws.WriteAtBuffer{} - - _, err = s.downloader.DownloadWithContext(ctx, + _, err = s.authProvider.S3().Downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{ Bucket: ptr.Of(s.metadata.Bucket), @@ -337,8 +342,7 @@ func (s *AWSS3) delete(ctx context.Context, req *bindings.InvokeRequest) (*bindi if key == "" { return nil, fmt.Errorf("s3 binding error: required metadata '%s' missing", metadataKey) } - - _, err := s.s3Client.DeleteObjectWithContext( + _, err := s.authProvider.S3().S3.DeleteObjectWithContext( ctx, &s3.DeleteObjectInput{ Bucket: ptr.Of(s.metadata.Bucket), @@ -367,8 +371,7 @@ func (s *AWSS3) list(ctx context.Context, req *bindings.InvokeRequest) (*binding if payload.MaxResults < 1 { payload.MaxResults = defaultMaxResults } - - result, err := s.s3Client.ListObjectsWithContext(ctx, &s3.ListObjectsInput{ + result, err := s.authProvider.S3().S3.ListObjectsWithContext(ctx, &s3.ListObjectsInput{ Bucket: ptr.Of(s.metadata.Bucket), MaxKeys: ptr.Of(int64(payload.MaxResults)), Marker: ptr.Of(payload.Marker), @@ -415,15 +418,6 @@ func (s *AWSS3) parseMetadata(md bindings.Metadata) (*s3Metadata, error) { return &m, nil } -func (s *AWSS3) getSession(metadata *s3Metadata) (*session.Session, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - - return sess, nil -} - // Helper to merge config and request metadata. func (metadata s3Metadata) mergeWithRequestMetadata(req *bindings.InvokeRequest) (s3Metadata, error) { merged := metadata diff --git a/bindings/aws/ses/ses.go b/bindings/aws/ses/ses.go index 483fde8c64..b8d2ff3faa 100644 --- a/bindings/aws/ses/ses.go +++ b/bindings/aws/ses/ses.go @@ -38,9 +38,9 @@ const ( // AWSSES is an AWS SNS binding. type AWSSES struct { - metadata *sesMetadata - logger logger.Logger - svc *ses.SES + authProvider awsAuth.Provider + metadata *sesMetadata + logger logger.Logger } type sesMetadata struct { @@ -61,19 +61,29 @@ func NewAWSSES(logger logger.Logger) bindings.OutputBinding { } // Init does metadata parsing. -func (a *AWSSES) Init(_ context.Context, metadata bindings.Metadata) error { +func (a *AWSSES) Init(ctx context.Context, metadata bindings.Metadata) error { // Parse input metadata - meta, err := a.parseMetadata(metadata) + m, err := a.parseMetadata(metadata) if err != nil { return err } - svc, err := a.getClient(meta) + a.metadata = m + + opts := awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: "", + } + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } - a.metadata = meta - a.svc = svc + a.authProvider = provider return nil } @@ -141,7 +151,7 @@ func (a *AWSSES) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind } // Attempt to send the email. - result, err := a.svc.SendEmail(input) + result, err := a.authProvider.Ses().Ses.SendEmail(input) if err != nil { return nil, fmt.Errorf("SES binding error. Sending email failed: %w", err) } @@ -158,18 +168,6 @@ func (metadata sesMetadata) mergeWithRequestMetadata(req *bindings.InvokeRequest return merged } -func (a *AWSSES) getClient(metadata *sesMetadata) (*ses.SES, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, "") - if err != nil { - return nil, fmt.Errorf("SES binding error: error creating AWS session %w", err) - } - - // Create an SES instance - svc := ses.New(sess) - - return svc, nil -} - // GetComponentMetadata returns the metadata of the component. func (a *AWSSES) GetComponentMetadata() (metadataInfo contribMetadata.MetadataMap) { metadataStruct := sesMetadata{} @@ -178,5 +176,8 @@ func (a *AWSSES) GetComponentMetadata() (metadataInfo contribMetadata.MetadataMa } func (a *AWSSES) Close() error { + if a.authProvider != nil { + return a.authProvider.Close() + } return nil } diff --git a/bindings/aws/sns/sns.go b/bindings/aws/sns/sns.go index 43b63cd2b1..370cabdf25 100644 --- a/bindings/aws/sns/sns.go +++ b/bindings/aws/sns/sns.go @@ -30,8 +30,8 @@ import ( // AWSSNS is an AWS SNS binding. type AWSSNS struct { - client *sns.SNS - topicARN string + authProvider awsAuth.Provider + topicARN string logger logger.Logger } @@ -43,6 +43,7 @@ type snsMetadata struct { SessionToken string `json:"sessionToken" mapstructure:"sessionToken" mdignore:"true"` TopicArn string `json:"topicArn"` + // TODO: in Dapr 1.17 rm the alias on region as we remove the aws prefix on these fields Region string `json:"region" mapstructure:"region" mapstructurealiases:"awsRegion" mdignore:"true"` Endpoint string `json:"endpoint"` } @@ -58,16 +59,27 @@ func NewAWSSNS(logger logger.Logger) bindings.OutputBinding { } // Init does metadata parsing. -func (a *AWSSNS) Init(_ context.Context, metadata bindings.Metadata) error { +func (a *AWSSNS) Init(ctx context.Context, metadata bindings.Metadata) error { m, err := a.parseMetadata(metadata) if err != nil { return err } - client, err := a.getClient(m) + + opts := awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + Endpoint: m.Endpoint, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, + } + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } - a.client = client + a.authProvider = provider a.topicARN = m.TopicArn return nil @@ -83,16 +95,6 @@ func (a *AWSSNS) parseMetadata(meta bindings.Metadata) (*snsMetadata, error) { return &m, nil } -func (a *AWSSNS) getClient(metadata *snsMetadata) (*sns.SNS, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - c := sns.New(sess) - - return c, nil -} - func (a *AWSSNS) Operations() []bindings.OperationKind { return []bindings.OperationKind{bindings.CreateOperation} } @@ -107,7 +109,7 @@ func (a *AWSSNS) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind msg := fmt.Sprintf("%v", payload.Message) subject := fmt.Sprintf("%v", payload.Subject) - _, err = a.client.PublishWithContext(ctx, &sns.PublishInput{ + _, err = a.authProvider.Sns().Sns.PublishWithContext(ctx, &sns.PublishInput{ Message: &msg, Subject: &subject, TopicArn: &a.topicARN, @@ -127,5 +129,8 @@ func (a *AWSSNS) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { } func (a *AWSSNS) Close() error { + if a.authProvider != nil { + return a.authProvider.Close() + } return nil } diff --git a/bindings/aws/sqs/sqs.go b/bindings/aws/sqs/sqs.go index 465e061b61..b09fde61f6 100644 --- a/bindings/aws/sqs/sqs.go +++ b/bindings/aws/sqs/sqs.go @@ -33,13 +33,12 @@ import ( // AWSSQS allows receiving and sending data to/from AWS SQS. type AWSSQS struct { - Client *sqs.SQS - QueueURL *string - - logger logger.Logger - wg sync.WaitGroup - closeCh chan struct{} - closed atomic.Bool + authProvider awsAuth.Provider + queueName string + logger logger.Logger + wg sync.WaitGroup + closeCh chan struct{} + closed atomic.Bool } type sqsMetadata struct { @@ -66,21 +65,22 @@ func (a *AWSSQS) Init(ctx context.Context, metadata bindings.Metadata) error { return err } - client, err := a.getClient(m) - if err != nil { - return err + opts := awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + Endpoint: m.Endpoint, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, } - - queueName := m.QueueName - resultURL, err := client.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{ - QueueName: aws.String(queueName), - }) + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } - - a.QueueURL = resultURL.QueueUrl - a.Client = client + a.authProvider = provider + a.queueName = m.QueueName return nil } @@ -91,9 +91,14 @@ func (a *AWSSQS) Operations() []bindings.OperationKind { func (a *AWSSQS) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { msgBody := string(req.Data) - _, err := a.Client.SendMessageWithContext(ctx, &sqs.SendMessageInput{ + url, err := a.authProvider.Sqs().QueueURL(ctx, a.queueName) + if err != nil { + a.logger.Errorf("failed to get queue url: %v", err) + } + + _, err = a.authProvider.Sqs().Sqs.SendMessageWithContext(ctx, &sqs.SendMessageInput{ MessageBody: &msgBody, - QueueUrl: a.QueueURL, + QueueUrl: url, }) return nil, err @@ -113,9 +118,13 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error { if ctx.Err() != nil || a.closed.Load() { return } + url, err := a.authProvider.Sqs().QueueURL(ctx, a.queueName) + if err != nil { + a.logger.Errorf("failed to get queue url: %v", err) + } - result, err := a.Client.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{ - QueueUrl: a.QueueURL, + result, err := a.authProvider.Sqs().Sqs.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{ + QueueUrl: url, AttributeNames: aws.StringSlice([]string{ "SentTimestamp", }), @@ -126,7 +135,7 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error { WaitTimeSeconds: aws.Int64(20), }) if err != nil { - a.logger.Errorf("Unable to receive message from queue %q, %v.", *a.QueueURL, err) + a.logger.Errorf("Unable to receive message from queue %q, %v.", url, err) } if len(result.Messages) > 0 { @@ -140,8 +149,8 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error { msgHandle := m.ReceiptHandle // Use a background context here because ctx may be canceled already - a.Client.DeleteMessageWithContext(context.Background(), &sqs.DeleteMessageInput{ - QueueUrl: a.QueueURL, + a.authProvider.Sqs().Sqs.DeleteMessageWithContext(context.Background(), &sqs.DeleteMessageInput{ + QueueUrl: url, ReceiptHandle: msgHandle, }) } @@ -164,6 +173,9 @@ func (a *AWSSQS) Close() error { close(a.closeCh) } a.wg.Wait() + if a.authProvider != nil { + return a.authProvider.Close() + } return nil } @@ -177,16 +189,6 @@ func (a *AWSSQS) parseSQSMetadata(meta bindings.Metadata) (*sqsMetadata, error) return &m, nil } -func (a *AWSSQS) getClient(metadata *sqsMetadata) (*sqs.SQS, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - c := sqs.New(sess) - - return c, nil -} - // GetComponentMetadata returns the metadata of the component. func (a *AWSSQS) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { metadataStruct := sqsMetadata{} diff --git a/bindings/azure/eventhubs/metadata.yaml b/bindings/azure/eventhubs/metadata.yaml index 439fc5a9e6..e5272bd62e 100644 --- a/bindings/azure/eventhubs/metadata.yaml +++ b/bindings/azure/eventhubs/metadata.yaml @@ -55,7 +55,15 @@ builtinAuthenticationProfiles: default: "false" example: "false" description: | + Allow management of the Event Hub namespace and storage account. + - name: enableInOrderMessageDelivery + type: bool + required: false + default: "false" + example: "false" + description: | + Enable in order processing of messages within a partition. - name: resourceGroupName type: string required: false @@ -153,3 +161,12 @@ metadata: description: | Storage container name. example: '"myeventhubstoragecontainer"' + - name: getAllMessageProperties + required: false + default: "false" + example: "false" + binding: + input: true + output: false + description: | + When set to true, will retrieve all message properties and include them in the returned event metadata diff --git a/bindings/kafka/metadata.yaml b/bindings/kafka/metadata.yaml index ab24f3e8fe..04667a9edc 100644 --- a/bindings/kafka/metadata.yaml +++ b/bindings/kafka/metadata.yaml @@ -14,6 +14,67 @@ binding: operations: - name: create description: "Publish a new message in the topic." +# This auth profile has duplicate fields intentionally as we maintain backwards compatibility, +# but also move Kafka to utilize the noramlized AWS fields in the builtin auth profiles. +# TODO: rm the duplicate aws prefixed fields in Dapr 1.17. +builtinAuthenticationProfiles: + - name: "aws" + metadata: + - name: authType + type: string + required: true + description: | + Authentication type. + This must be set to "awsiam" for this authentication profile. + example: '"awsiam"' + allowedValues: + - "awsiam" + - name: awsAccessKey + type: string + required: false + description: | + This maintains backwards compatibility with existing fields. + It will be deprecated as of Dapr 1.17. Use 'accessKey' instead. + If both fields are set, then 'accessKey' value will be used. + AWS access key associated with an IAM account. + example: '"AKIAIOSFODNN7EXAMPLE"' + - name: awsSecretKey + type: string + required: false + sensitive: true + description: | + This maintains backwards compatibility with existing fields. + It will be deprecated as of Dapr 1.17. Use 'secretKey' instead. + If both fields are set, then 'secretKey' value will be used. + The secret key associated with the access key. + example: '"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"' + - name: awsSessionToken + type: string + sensitive: true + description: | + This maintains backwards compatibility with existing fields. + It will be deprecated as of Dapr 1.17. Use 'sessionToken' instead. + If both fields are set, then 'sessionToken' value will be used. + AWS session token to use. A session token is only required if you are using temporary security credentials. + example: '"TOKEN"' + - name: awsIamRoleArn + type: string + required: false + description: | + This maintains backwards compatibility with existing fields. + It will be deprecated as of Dapr 1.17. Use 'assumeRoleArn' instead. + If both fields are set, then 'assumeRoleArn' value will be used. + IAM role that has access to MSK. This is another option to authenticate with MSK aside from the AWS Credentials. + example: '"arn:aws:iam::123456789:role/mskRole"' + - name: awsStsSessionName + type: string + description: | + This maintains backwards compatibility with existing fields. + It will be deprecated as of Dapr 1.17. Use 'sessionName' instead. + If both fields are set, then 'sessionName' value will be used. + Represents the session name for assuming a role. + example: '"MyAppSession"' + default: '"DaprDefaultSession"' authenticationProfiles: - title: "OIDC Authentication" description: | @@ -139,55 +200,6 @@ authenticationProfiles: example: '"none"' allowedValues: - "none" - - title: "AWS IAM" - description: "Authenticate using AWS IAM credentials or role for AWS MSK" - metadata: - - name: authType - type: string - required: true - description: | - Authentication type. - This must be set to "awsiam" for this authentication profile. - example: '"awsiam"' - allowedValues: - - "awsiam" - - name: awsRegion - type: string - required: true - description: | - The AWS Region where the MSK Kafka broker is deployed to. - example: '"us-east-1"' - - name: awsAccessKey - type: string - required: true - description: | - AWS access key associated with an IAM account. - example: '"AKIAIOSFODNN7EXAMPLE"' - - name: awsSecretKey - type: string - required: true - sensitive: true - description: | - The secret key associated with the access key. - example: '"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"' - - name: awsSessionToken - type: string - sensitive: true - description: | - AWS session token to use. A session token is only required if you are using temporary security credentials. - example: '"TOKEN"' - - name: awsIamRoleArn - type: string - required: true - description: | - IAM role that has access to MSK. This is another option to authenticate with MSK aside from the AWS Credentials. - example: '"arn:aws:iam::123456789:role/mskRole"' - - name: awsStsSessionName - type: string - description: | - Represents the session name for assuming a role. - example: '"MyAppSession"' - default: '"MSKSASLDefaultSession"' metadata: - name: topics type: string diff --git a/bindings/postgres/metadata.go b/bindings/postgres/metadata.go index b4747c33ff..44e55a37f3 100644 --- a/bindings/postgres/metadata.go +++ b/bindings/postgres/metadata.go @@ -14,6 +14,7 @@ limitations under the License. package postgres import ( + "errors" "time" "github.com/dapr/components-contrib/common/authentication/aws" @@ -27,7 +28,7 @@ const ( type psqlMetadata struct { pgauth.PostgresAuthMetadata `mapstructure:",squash"` - aws.AWSIAM `mapstructure:",squash"` + aws.DeprecatedPostgresIAM `mapstructure:",squash"` Timeout time.Duration `mapstructure:"timeout" mapstructurealiases:"timeoutInSeconds"` } @@ -53,5 +54,9 @@ func (m *psqlMetadata) InitWithMetadata(meta map[string]string) error { return err } + if m.Timeout < 1*time.Second { + return errors.New("invalid value for 'timeout': must be greater than 1s") + } + return nil } diff --git a/bindings/postgres/metadata.yaml b/bindings/postgres/metadata.yaml index 6a9908a038..e702a82c64 100644 --- a/bindings/postgres/metadata.yaml +++ b/bindings/postgres/metadata.yaml @@ -56,23 +56,21 @@ builtinAuthenticationProfiles: example: | "host=mydb.postgres.database.aws.com user=myapplication port=5432 dbname=dapr_test sslmode=require" type: string - - name: awsRegion - type: string - required: true - description: | - The AWS Region where the AWS Relational Database Service is deployed to. - example: '"us-east-1"' - name: awsAccessKey type: string - required: true + required: false description: | + Deprecated as of Dapr 1.17. Use 'accessKey' instead if using AWS IAM. + If both fields are set, then 'accessKey' value will be used. AWS access key associated with an IAM account. example: '"AKIAIOSFODNN7EXAMPLE"' - name: awsSecretKey type: string - required: true + required: false sensitive: true description: | + Deprecated as of Dapr 1.17. Use 'secretKey' instead if using AWS IAM. + If both fields are set, then 'secretKey' value will be used. The secret key associated with the access key. example: '"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"' authenticationProfiles: diff --git a/bindings/postgres/metadata_test.go b/bindings/postgres/metadata_test.go new file mode 100644 index 0000000000..ece5e433ed --- /dev/null +++ b/bindings/postgres/metadata_test.go @@ -0,0 +1,88 @@ +/* +Copyright 2023 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package postgres + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMetadata(t *testing.T) { + t.Run("missing connection string", func(t *testing.T) { + m := psqlMetadata{} + props := map[string]string{} + + err := m.InitWithMetadata(props) + require.Error(t, err) + require.ErrorContains(t, err, "connection string") + }) + + t.Run("has connection string", func(t *testing.T) { + m := psqlMetadata{} + props := map[string]string{ + "connectionString": "foo", + } + + err := m.InitWithMetadata(props) + require.NoError(t, err) + }) + + t.Run("default timeout", func(t *testing.T) { + m := psqlMetadata{} + props := map[string]string{ + "connectionString": "foo", + } + + err := m.InitWithMetadata(props) + require.NoError(t, err) + assert.Equal(t, 20*time.Second, m.Timeout) + }) + + t.Run("invalid timeout", func(t *testing.T) { + m := psqlMetadata{} + props := map[string]string{ + "connectionString": "foo", + "timeout": "NaN", + } + + err := m.InitWithMetadata(props) + require.Error(t, err) + }) + + t.Run("positive timeout", func(t *testing.T) { + m := psqlMetadata{} + props := map[string]string{ + "connectionString": "foo", + "timeout": "42", + } + + err := m.InitWithMetadata(props) + require.NoError(t, err) + assert.Equal(t, 42*time.Second, m.Timeout) + }) + + t.Run("zero timeout", func(t *testing.T) { + m := psqlMetadata{} + props := map[string]string{ + "connectionString": "foo", + "timeout": "0", + } + + err := m.InitWithMetadata(props) + require.Error(t, err) + }) +} diff --git a/bindings/postgres/postgres.go b/bindings/postgres/postgres.go index c9dc6bcfbe..c0729cab06 100644 --- a/bindings/postgres/postgres.go +++ b/bindings/postgres/postgres.go @@ -26,6 +26,8 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/dapr/components-contrib/bindings" + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" + pgauth "github.com/dapr/components-contrib/common/authentication/postgresql" "github.com/dapr/components-contrib/metadata" "github.com/dapr/kit/logger" ) @@ -45,6 +47,11 @@ type Postgres struct { logger logger.Logger db *pgxpool.Pool closed atomic.Bool + + enableAzureAD bool + enableAWSIAM bool + + awsAuthProvider awsAuth.Provider } // NewPostgres returns a new PostgreSQL output binding. @@ -59,25 +66,52 @@ func (p *Postgres) Init(ctx context.Context, meta bindings.Metadata) error { if p.closed.Load() { return errors.New("cannot initialize a previously-closed component") } - + opts := pgauth.InitWithMetadataOpts{ + AzureADEnabled: p.enableAzureAD, + AWSIAMEnabled: p.enableAWSIAM, + } m := psqlMetadata{} - err := m.InitWithMetadata(meta.Properties) - if err != nil { + if err := m.InitWithMetadata(meta.Properties); err != nil { return err } + var err error poolConfig, err := m.GetPgxPoolConfig() if err != nil { return err } + if opts.AWSIAMEnabled && m.UseAWSIAM { + opts, validateErr := m.BuildAwsIamOptions(p.logger, meta.Properties) + if validateErr != nil { + return fmt.Errorf("failed to validate AWS IAM authentication fields: %w", validateErr) + } + + var provider awsAuth.Provider + provider, err = awsAuth.NewProvider(ctx, *opts, awsAuth.GetConfig(*opts)) + if err != nil { + return err + } + p.awsAuthProvider = provider + p.awsAuthProvider.UpdatePostgres(ctx, poolConfig) + } + // This context doesn't control the lifetime of the connection pool, and is // only scoped to postgres creating resources at init. - p.db, err = pgxpool.NewWithConfig(ctx, poolConfig) + connCtx, connCancel := context.WithTimeout(ctx, m.Timeout) + defer connCancel() + p.db, err = pgxpool.NewWithConfig(connCtx, poolConfig) if err != nil { return fmt.Errorf("unable to connect to the DB: %w", err) } + pingCtx, pingCancel := context.WithTimeout(ctx, m.Timeout) + defer pingCancel() + err = p.db.Ping(pingCtx) + if err != nil { + return fmt.Errorf("failed to ping the DB: %w", err) + } + return nil } @@ -177,7 +211,11 @@ func (p *Postgres) Close() error { } p.db = nil - return nil + errs := make([]error, 1) + if p.awsAuthProvider != nil { + errs[0] = p.awsAuthProvider.Close() + } + return errors.Join(errs...) } func (p *Postgres) query(ctx context.Context, sql string, args ...any) (result []byte, err error) { diff --git a/bindings/postgres/postgres_test.go b/bindings/postgres/postgres_test.go index c24fc099fb..6a517fcd6a 100644 --- a/bindings/postgres/postgres_test.go +++ b/bindings/postgres/postgres_test.go @@ -15,6 +15,7 @@ package postgres import ( "context" + "errors" "fmt" "os" "testing" @@ -62,6 +63,10 @@ func TestPostgresIntegration(t *testing.T) { t.SkipNow() } + t.Run("Test init configurations", func(t *testing.T) { + testInitConfiguration(t, url) + }) + // live DB test b := NewPostgres(logger.NewLogger("test")).(*Postgres) m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{"connectionString": url}}} @@ -131,6 +136,46 @@ func TestPostgresIntegration(t *testing.T) { }) } +// testInitConfiguration tests valid and invalid config settings. +func testInitConfiguration(t *testing.T, connectionString string) { + logger := logger.NewLogger("test") + tests := []struct { + name string + props map[string]string + expectedErr error + }{ + { + name: "Empty", + props: map[string]string{}, + expectedErr: errors.New("missing connection string"), + }, + { + name: "Valid connection string", + props: map[string]string{"connectionString": connectionString}, + expectedErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := NewPostgres(logger).(*Postgres) + defer p.Close() + + metadata := bindings.Metadata{ + Base: metadata.Base{Properties: tt.props}, + } + + err := p.Init(context.Background(), metadata) + if tt.expectedErr == nil { + require.NoError(t, err) + } else { + require.Error(t, err) + assert.Equal(t, tt.expectedErr, err) + } + }) + } +} + func assertResponse(t *testing.T, res *bindings.InvokeResponse, err error) { require.NoError(t, err) assert.NotNil(t, res) diff --git a/common/authentication/aws/aws.go b/common/authentication/aws/aws.go index 48c8b209a4..2728249afd 100644 --- a/common/authentication/aws/aws.go +++ b/common/authentication/aws/aws.go @@ -15,20 +15,8 @@ package aws import ( "context" - "errors" - "fmt" - "strconv" - "time" - - awsv2 "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/config" - v2creds "github.com/aws/aws-sdk-go-v2/credentials" - "github.com/aws/aws-sdk-go-v2/feature/rds/auth" + "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/dapr/kit/logger" @@ -38,59 +26,85 @@ type EnvironmentSettings struct { Metadata map[string]string } -func GetConfigV2(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (awsv2.Config, error) { - optFns := []func(*config.LoadOptions) error{} - if region != "" { - optFns = append(optFns, config.WithRegion(region)) - } +// TODO: Delete in Dapr 1.17 so we can move all IAM fields to use the defaults of: +// accessKey and secretKey and region as noted in the docs, and Options struct above. +type DeprecatedKafkaIAM struct { + Region string `json:"awsRegion" mapstructure:"awsRegion"` + AccessKey string `json:"awsAccessKey" mapstructure:"awsAccessKey"` + SecretKey string `json:"awsSecretKey" mapstructure:"awsSecretKey"` + SessionToken string `json:"awsSessionToken" mapstructure:"awsSessionToken"` + IamRoleArn string `json:"awsIamRoleArn" mapstructure:"awsIamRoleArn"` + StsSessionName string `json:"awsStsSessionName" mapstructure:"awsStsSessionName"` +} - if accessKey != "" && secretKey != "" { - provider := v2creds.NewStaticCredentialsProvider(accessKey, secretKey, sessionToken) - optFns = append(optFns, config.WithCredentialsProvider(provider)) - } +type Options struct { + Logger logger.Logger + Properties map[string]string - awsCfg, err := config.LoadDefaultConfig(context.Background(), optFns...) - if err != nil { - return awsv2.Config{}, err - } + PoolConfig *pgxpool.Config `json:"poolConfig" mapstructure:"poolConfig"` + ConnectionString string `json:"connectionString" mapstructure:"connectionString"` - if endpoint != "" { - awsCfg.BaseEndpoint = &endpoint - } + // TODO: in Dapr 1.17 rm the alias on regions as we rm the aws prefixed one. + // Docs have it just as region, but most metadata fields show the aws prefix... + Region string `json:"region" mapstructure:"region" mapstructurealiases:"awsRegion"` + AccessKey string `json:"accessKey" mapstructure:"accessKey"` + SecretKey string `json:"secretKey" mapstructure:"secretKey"` + SessionName string `json:"sessionName" mapstructure:"sessionName"` + AssumeRoleARN string `json:"assumeRoleArn" mapstructure:"assumeRoleArn"` + SessionToken string `json:"sessionToken" mapstructure:"sessionToken"` - return awsCfg, nil + Endpoint string } -func GetClient(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (*session.Session, error) { - awsConfig := aws.NewConfig() +// TODO: Delete in Dapr 1.17 so we can move all IAM fields to use the defaults of: +// accessKey and secretKey and region as noted in the docs, and Options struct above. +type DeprecatedPostgresIAM struct { + // Access key to use for accessing PostgreSQL. + AccessKey string `json:"awsAccessKey" mapstructure:"awsAccessKey"` + // Secret key to use for accessing PostgreSQL. + SecretKey string `json:"awsSecretKey" mapstructure:"awsSecretKey"` +} - if region != "" { - awsConfig = awsConfig.WithRegion(region) - } +func GetConfig(opts Options) *aws.Config { + cfg := aws.NewConfig() - if accessKey != "" && secretKey != "" { - awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(accessKey, secretKey, sessionToken)) + switch { + case opts.Region != "": + cfg.WithRegion(opts.Region) + case opts.Endpoint != "": + cfg.WithEndpoint(opts.Endpoint) } - if endpoint != "" { - awsConfig = awsConfig.WithEndpoint(endpoint) - } + return cfg +} - awsSession, err := session.NewSessionWithOptions(session.Options{ - Config: *awsConfig, - SharedConfigState: session.SharedConfigEnable, - }) - if err != nil { - return nil, err - } +//nolint:interfacebloat +type Provider interface { + S3() *S3Clients + DynamoDB() *DynamoDBClients + Sqs() *SqsClients + Sns() *SnsClients + SnsSqs() *SnsSqsClients + SecretManager() *SecretManagerClients + ParameterStore() *ParameterStoreClients + Kinesis() *KinesisClients + Ses() *SesClients + Kafka(KafkaOptions) (*KafkaClients, error) + + // Postgres is an outlier to the others in the sense that we can update only it's config, + // as we use a max connection time of 8 minutes. + // This means that we can just update the config session credentials, + // and then in 8 minutes it will update to a new session automatically for us. + UpdatePostgres(context.Context, *pgxpool.Config) + + Close() error +} - userAgentHandler := request.NamedHandler{ - Name: "UserAgentHandler", - Fn: request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion), +func NewProvider(ctx context.Context, opts Options, cfg *aws.Config) (Provider, error) { + if isX509Auth(opts.Properties) { + return newX509(ctx, opts, cfg) } - awsSession.Handlers.Build.PushBackNamed(userAgentHandler) - - return awsSession, nil + return newStaticIAM(ctx, opts, cfg) } // NewEnvironmentSettings returns a new EnvironmentSettings configured for a given AWS resource. @@ -102,83 +116,13 @@ func NewEnvironmentSettings(md map[string]string) (EnvironmentSettings, error) { return es, nil } -type AWSIAM struct { - // Ignored by metadata parser because included in built-in authentication profile - // Access key to use for accessing PostgreSQL. - AWSAccessKey string `json:"awsAccessKey" mapstructure:"awsAccessKey"` - // Secret key to use for accessing PostgreSQL. - AWSSecretKey string `json:"awsSecretKey" mapstructure:"awsSecretKey"` - // AWS region in which PostgreSQL is deployed. - AWSRegion string `json:"awsRegion" mapstructure:"awsRegion"` -} - -type AWSIAMAuthOptions struct { - PoolConfig *pgxpool.Config `json:"poolConfig" mapstructure:"poolConfig"` - ConnectionString string `json:"connectionString" mapstructure:"connectionString"` - Region string `json:"region" mapstructure:"region"` - AccessKey string `json:"accessKey" mapstructure:"accessKey"` - SecretKey string `json:"secretKey" mapstructure:"secretKey"` -} - -func (opts *AWSIAMAuthOptions) GetAccessToken(ctx context.Context) (string, error) { - dbEndpoint := opts.PoolConfig.ConnConfig.Host + ":" + strconv.Itoa(int(opts.PoolConfig.ConnConfig.Port)) - var authenticationToken string - - // https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/UsingWithRDS.IAMDBAuth.Connecting.Go.html - // Default to load default config through aws credentials file (~/.aws/credentials) - awsCfg, err := config.LoadDefaultConfig(ctx) - // Note: in the event of an error with invalid config or failed to load config, - // then we fall back to using the access key and secret key. - switch { - case errors.Is(err, config.SharedConfigAssumeRoleError{}.Err), - errors.Is(err, config.SharedConfigLoadError{}.Err), - errors.Is(err, config.SharedConfigProfileNotExistError{}.Err): - // Validate if access key and secret access key are provided - if opts.AccessKey == "" || opts.SecretKey == "" { - return "", fmt.Errorf("failed to load default configuration for AWS using accessKey and secretKey: %w", err) - } - - // Set credentials explicitly - awsCfg := v2creds.NewStaticCredentialsProvider(opts.AccessKey, opts.SecretKey, "") - authenticationToken, err = auth.BuildAuthToken( - ctx, dbEndpoint, opts.Region, opts.PoolConfig.ConnConfig.User, awsCfg) - if err != nil { - return "", fmt.Errorf("failed to create AWS authentication token: %w", err) - } - - return authenticationToken, nil - case err != nil: - return "", errors.New("failed to load default AWS authentication configuration") - } - - authenticationToken, err = auth.BuildAuthToken( - ctx, dbEndpoint, opts.Region, opts.PoolConfig.ConnConfig.User, awsCfg.Credentials) - if err != nil { - return "", fmt.Errorf("failed to create AWS authentication token: %w", err) - } - - return authenticationToken, nil -} - -func (opts *AWSIAMAuthOptions) InitiateAWSIAMAuth() error { - // Set max connection lifetime to 8 minutes in postgres connection pool configuration. - // Note: this will refresh connections before the 15 min expiration on the IAM AWS auth token, - // while leveraging the BeforeConnect hook to recreate the token in time dynamically. - opts.PoolConfig.MaxConnLifetime = time.Minute * 8 - - // Setup connection pool config needed for AWS IAM authentication - opts.PoolConfig.BeforeConnect = func(ctx context.Context, pgConfig *pgx.ConnConfig) error { - // Manually reset auth token with aws and reset the config password using the new iam token - pwd, errGetAccessToken := opts.GetAccessToken(ctx) - if errGetAccessToken != nil { - return fmt.Errorf("failed to refresh access token for iam authentication with PostgreSQL: %w", errGetAccessToken) +// Coalesce is a helper function to return the first non-empty string from the inputs +// This helps us to migrate away from the deprecated duplicate aws auth profile metadata fields in Dapr 1.17. +func Coalesce(values ...string) string { + for _, v := range values { + if v != "" { + return v } - - pgConfig.Password = pwd - opts.PoolConfig.ConnConfig.Password = pwd - - return nil } - - return nil + return "" } diff --git a/common/authentication/aws/aws_test.go b/common/authentication/aws/aws_test.go new file mode 100644 index 0000000000..15aac78ad7 --- /dev/null +++ b/common/authentication/aws/aws_test.go @@ -0,0 +1,44 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewEnvironmentSettings(t *testing.T) { + tests := []struct { + name string + metadata map[string]string + }{ + { + name: "valid metadata", + metadata: map[string]string{ + "key1": "value1", + "key2": "value2", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := NewEnvironmentSettings(tt.metadata) + require.NoError(t, err) + assert.NotNil(t, result) + }) + } +} diff --git a/common/authentication/aws/client.go b/common/authentication/aws/client.go new file mode 100644 index 0000000000..11b26e4988 --- /dev/null +++ b/common/authentication/aws/client.go @@ -0,0 +1,351 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/IBM/sarama" + "github.com/aws/aws-msk-iam-sasl-signer-go/signer" + aws2 "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" + "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/aws/aws-sdk-go/service/secretsmanager" + "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" + "github.com/aws/aws-sdk-go/service/ses" + "github.com/aws/aws-sdk-go/service/sns" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" + "github.com/aws/aws-sdk-go/service/ssm" + "github.com/aws/aws-sdk-go/service/ssm/ssmiface" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/vmware/vmware-go-kcl/clientlibrary/config" +) + +type Clients struct { + mu sync.RWMutex + + s3 *S3Clients + Dynamo *DynamoDBClients + sns *SnsClients + sqs *SqsClients + snssqs *SnsSqsClients + Secret *SecretManagerClients + ParameterStore *ParameterStoreClients + kinesis *KinesisClients + ses *SesClients + kafka *KafkaClients +} + +func newClients() *Clients { + return new(Clients) +} + +func (c *Clients) refresh(session *session.Session) error { + c.mu.Lock() + defer c.mu.Unlock() + switch { + case c.s3 != nil: + c.s3.New(session) + case c.Dynamo != nil: + c.Dynamo.New(session) + case c.sns != nil: + c.sns.New(session) + case c.sqs != nil: + c.sqs.New(session) + case c.snssqs != nil: + c.snssqs.New(session) + case c.Secret != nil: + c.Secret.New(session) + case c.ParameterStore != nil: + c.ParameterStore.New(session) + case c.kinesis != nil: + c.kinesis.New(session) + case c.ses != nil: + c.ses.New(session) + case c.kafka != nil: + // Note: we pass in nil for token provider + // as there are no special fields for x509 auth for it. + // Only static auth passes it in. + err := c.kafka.New(session, nil) + if err != nil { + return fmt.Errorf("failed to refresh Kafka AWS IAM Config: %w", err) + } + } + return nil +} + +type S3Clients struct { + S3 *s3.S3 + Uploader *s3manager.Uploader + Downloader *s3manager.Downloader +} + +type DynamoDBClients struct { + DynamoDB dynamodbiface.DynamoDBAPI +} + +type SnsSqsClients struct { + Sns *sns.SNS + Sqs *sqs.SQS + Sts *sts.STS +} + +type SnsClients struct { + Sns *sns.SNS +} + +type SqsClients struct { + Sqs sqsiface.SQSAPI +} + +type SecretManagerClients struct { + Manager secretsmanageriface.SecretsManagerAPI +} + +type ParameterStoreClients struct { + Store ssmiface.SSMAPI +} + +type KinesisClients struct { + Kinesis kinesisiface.KinesisAPI + Region string + Credentials *credentials.Credentials +} + +type SesClients struct { + Ses *ses.SES +} + +type KafkaClients struct { + config *sarama.Config + consumerGroup *string + brokers *[]string + maxMessageBytes *int + + ConsumerGroup sarama.ConsumerGroup + Producer sarama.SyncProducer +} + +func (c *S3Clients) New(session *session.Session) { + refreshedS3 := s3.New(session, session.Config) + c.S3 = refreshedS3 + c.Uploader = s3manager.NewUploaderWithClient(refreshedS3) + c.Downloader = s3manager.NewDownloaderWithClient(refreshedS3) +} + +func (c *DynamoDBClients) New(session *session.Session) { + c.DynamoDB = dynamodb.New(session, session.Config) +} + +func (c *SnsClients) New(session *session.Session) { + c.Sns = sns.New(session, session.Config) +} + +func (c *SnsSqsClients) New(session *session.Session) { + c.Sns = sns.New(session, session.Config) + c.Sqs = sqs.New(session, session.Config) + c.Sts = sts.New(session, session.Config) +} + +func (c *SqsClients) New(session *session.Session) { + c.Sqs = sqs.New(session, session.Config) +} + +func (c *SqsClients) QueueURL(ctx context.Context, queueName string) (*string, error) { + if c.Sqs != nil { + resultURL, err := c.Sqs.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{ + QueueName: aws.String(queueName), + }) + if resultURL != nil { + return resultURL.QueueUrl, err + } + } + return nil, errors.New("unable to get queue url due to empty client") +} + +func (c *SecretManagerClients) New(session *session.Session) { + c.Manager = secretsmanager.New(session, session.Config) +} + +func (c *ParameterStoreClients) New(session *session.Session) { + c.Store = ssm.New(session, session.Config) +} + +func (c *KinesisClients) New(session *session.Session) { + c.Kinesis = kinesis.New(session, session.Config) + c.Region = *session.Config.Region + c.Credentials = session.Config.Credentials +} + +func (c *KinesisClients) Stream(ctx context.Context, streamName string) (*string, error) { + if c.Kinesis != nil { + stream, err := c.Kinesis.DescribeStreamWithContext(ctx, &kinesis.DescribeStreamInput{ + StreamName: aws.String(streamName), + }) + if stream != nil { + return stream.StreamDescription.StreamARN, err + } + } + + return nil, errors.New("unable to get stream arn due to empty client") +} + +func (c *KinesisClients) WorkerCfg(ctx context.Context, stream, consumer, mode string) *config.KinesisClientLibConfiguration { + const sharedMode = "shared" + if c.Kinesis != nil { + if mode == sharedMode { + if c.Credentials != nil { + kclConfig := config.NewKinesisClientLibConfigWithCredential(consumer, + stream, c.Region, consumer, + c.Credentials) + return kclConfig + } + } + } + + return nil +} + +func (c *SesClients) New(session *session.Session) { + c.Ses = ses.New(session, session.Config) +} + +type KafkaOptions struct { + Config *sarama.Config + ConsumerGroup string + Brokers []string + MaxMessageBytes int +} + +func initKafkaClients(opts KafkaOptions) *KafkaClients { + return &KafkaClients{ + config: opts.Config, + consumerGroup: &opts.ConsumerGroup, + brokers: &opts.Brokers, + maxMessageBytes: &opts.MaxMessageBytes, + } +} + +func (c *KafkaClients) New(session *session.Session, tokenProvider *mskTokenProvider) error { + const timeout = 10 * time.Second + creds, err := session.Config.Credentials.Get() + if err != nil { + return fmt.Errorf("failed to get credentials from session: %w", err) + } + + // fill in token provider common fields across x509 and static auth + if tokenProvider == nil { + tokenProvider = &mskTokenProvider{} + } + tokenProvider.generateTokenTimeout = timeout + tokenProvider.region = *session.Config.Region + tokenProvider.accessKey = creds.AccessKeyID + tokenProvider.secretKey = creds.SecretAccessKey + tokenProvider.sessionToken = creds.SessionToken + + c.config.Net.SASL.Enable = true + c.config.Net.SASL.Mechanism = sarama.SASLTypeOAuth + c.config.Net.SASL.TokenProvider = tokenProvider + + _, err = c.config.Net.SASL.TokenProvider.Token() + if err != nil { + return fmt.Errorf("error validating iam credentials %v", err) + } + + consumerGroup, err := sarama.NewConsumerGroup(*c.brokers, *c.consumerGroup, c.config) + if err != nil { + return err + } + c.ConsumerGroup = consumerGroup + + producer, err := c.getSyncProducer() + if err != nil { + return err + } + c.Producer = producer + + return nil +} + +// Kafka specific +type mskTokenProvider struct { + generateTokenTimeout time.Duration + accessKey string + secretKey string + sessionToken string + awsIamRoleArn string + awsStsSessionName string + region string +} + +func (m *mskTokenProvider) Token() (*sarama.AccessToken, error) { + // this function can't use the context passed on Init because that context would be cancelled right after Init + ctx, cancel := context.WithTimeout(context.Background(), m.generateTokenTimeout) + defer cancel() + + switch { + // we must first check if we are using the assume role auth profile + case m.awsIamRoleArn != "" && m.awsStsSessionName != "": + token, _, err := signer.GenerateAuthTokenFromRole(ctx, m.region, m.awsIamRoleArn, m.awsStsSessionName) + return &sarama.AccessToken{Token: token}, err + case m.accessKey != "" && m.secretKey != "": + token, _, err := signer.GenerateAuthTokenFromCredentialsProvider(ctx, m.region, aws2.CredentialsProviderFunc(func(ctx context.Context) (aws2.Credentials, error) { + return aws2.Credentials{ + AccessKeyID: m.accessKey, + SecretAccessKey: m.secretKey, + SessionToken: m.sessionToken, + }, nil + })) + return &sarama.AccessToken{Token: token}, err + + default: // load default aws creds + token, _, err := signer.GenerateAuthToken(ctx, m.region) + return &sarama.AccessToken{Token: token}, err + } +} + +func (c *KafkaClients) getSyncProducer() (sarama.SyncProducer, error) { + // Add SyncProducer specific properties to copy of base config + c.config.Producer.RequiredAcks = sarama.WaitForAll + c.config.Producer.Retry.Max = 5 + c.config.Producer.Return.Successes = true + + if *c.maxMessageBytes > 0 { + c.config.Producer.MaxMessageBytes = *c.maxMessageBytes + } + + saramaClient, err := sarama.NewClient(*c.brokers, c.config) + if err != nil { + return nil, err + } + + producer, err := sarama.NewSyncProducerFromClient(saramaClient) + if err != nil { + return nil, err + } + + return producer, nil +} diff --git a/common/authentication/aws/client_fake.go b/common/authentication/aws/client_fake.go new file mode 100644 index 0000000000..c9e23641ba --- /dev/null +++ b/common/authentication/aws/client_fake.go @@ -0,0 +1,79 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" + "github.com/aws/aws-sdk-go/service/secretsmanager" + "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" + "github.com/aws/aws-sdk-go/service/ssm" + "github.com/aws/aws-sdk-go/service/ssm/ssmiface" +) + +type MockParameterStore struct { + GetParameterFn func(context.Context, *ssm.GetParameterInput, ...request.Option) (*ssm.GetParameterOutput, error) + DescribeParametersFn func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) + ssmiface.SSMAPI +} + +func (m *MockParameterStore) GetParameterWithContext(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + return m.GetParameterFn(ctx, input, option...) +} + +func (m *MockParameterStore) DescribeParametersWithContext(ctx context.Context, input *ssm.DescribeParametersInput, option ...request.Option) (*ssm.DescribeParametersOutput, error) { + return m.DescribeParametersFn(ctx, input, option...) +} + +type MockSecretManager struct { + GetSecretValueFn func(context.Context, *secretsmanager.GetSecretValueInput, ...request.Option) (*secretsmanager.GetSecretValueOutput, error) + secretsmanageriface.SecretsManagerAPI +} + +func (m *MockSecretManager) GetSecretValueWithContext(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { + return m.GetSecretValueFn(ctx, input, option...) +} + +type MockDynamoDB struct { + GetItemWithContextFn func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) + PutItemWithContextFn func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error) + DeleteItemWithContextFn func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error) + BatchWriteItemWithContextFn func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error) + TransactWriteItemsWithContextFn func(aws.Context, *dynamodb.TransactWriteItemsInput, ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) + dynamodbiface.DynamoDBAPI +} + +func (m *MockDynamoDB) GetItemWithContext(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { + return m.GetItemWithContextFn(ctx, input, op...) +} + +func (m *MockDynamoDB) PutItemWithContext(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error) { + return m.PutItemWithContextFn(ctx, input, op...) +} + +func (m *MockDynamoDB) DeleteItemWithContext(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error) { + return m.DeleteItemWithContextFn(ctx, input, op...) +} + +func (m *MockDynamoDB) BatchWriteItemWithContext(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error) { + return m.BatchWriteItemWithContextFn(ctx, input, op...) +} + +func (m *MockDynamoDB) TransactWriteItemsWithContext(ctx context.Context, input *dynamodb.TransactWriteItemsInput, op ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) { + return m.TransactWriteItemsWithContextFn(ctx, input, op...) +} diff --git a/common/authentication/aws/client_test.go b/common/authentication/aws/client_test.go new file mode 100644 index 0000000000..67d2ac88f3 --- /dev/null +++ b/common/authentication/aws/client_test.go @@ -0,0 +1,265 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + "errors" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vmware/vmware-go-kcl/clientlibrary/config" +) + +type mockedSQS struct { + sqsiface.SQSAPI + GetQueueURLFn func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) +} + +func (m *mockedSQS) GetQueueUrlWithContext(ctx context.Context, input *sqs.GetQueueUrlInput, opts ...request.Option) (*sqs.GetQueueUrlOutput, error) { //nolint:stylecheck + return m.GetQueueURLFn(ctx, input) +} + +type mockedKinesis struct { + kinesisiface.KinesisAPI + DescribeStreamFn func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) +} + +func (m *mockedKinesis) DescribeStreamWithContext(ctx context.Context, input *kinesis.DescribeStreamInput, opts ...request.Option) (*kinesis.DescribeStreamOutput, error) { + return m.DescribeStreamFn(ctx, input) +} + +func TestS3Clients_New(t *testing.T) { + tests := []struct { + name string + s3Client *S3Clients + session *session.Session + }{ + {"initializes S3 client", &S3Clients{}, session.Must(session.NewSession())}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.s3Client.New(tt.session) + require.NotNil(t, tt.s3Client.S3) + require.NotNil(t, tt.s3Client.Uploader) + require.NotNil(t, tt.s3Client.Downloader) + }) + } +} + +func TestSqsClients_QueueURL(t *testing.T) { + tests := []struct { + name string + mockFn func() *mockedSQS + queueName string + expectedURL *string + expectError bool + }{ + { + name: "returns queue URL successfully", + mockFn: func() *mockedSQS { + return &mockedSQS{ + GetQueueURLFn: func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) { + return &sqs.GetQueueUrlOutput{ + QueueUrl: aws.String("https://sqs.aws.com/123456789012/queue"), + }, nil + }, + } + }, + queueName: "valid-queue", + expectedURL: aws.String("https://sqs.aws.com/123456789012/queue"), + expectError: false, + }, + { + name: "returns error when queue URL not found", + mockFn: func() *mockedSQS { + return &mockedSQS{ + GetQueueURLFn: func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) { + return nil, errors.New("unable to get stream arn due to empty client") + }, + } + }, + queueName: "missing-queue", + expectedURL: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockSQS := tt.mockFn() + + // Initialize SqsClients with the mocked SQS client + client := &SqsClients{ + Sqs: mockSQS, + } + + url, err := client.QueueURL(context.Background(), tt.queueName) + + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedURL, url) + } + }) + } +} + +func TestKinesisClients_Stream(t *testing.T) { + tests := []struct { + name string + kinesisClient *KinesisClients + streamName string + mockStreamARN *string + mockError error + expectedStream *string + expectedErr error + }{ + { + name: "successfully retrieves stream ARN", + kinesisClient: &KinesisClients{ + Kinesis: &mockedKinesis{DescribeStreamFn: func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { + return &kinesis.DescribeStreamOutput{ + StreamDescription: &kinesis.StreamDescription{ + StreamARN: aws.String("arn:aws:kinesis:some-region:123456789012:stream/some-stream"), + }, + }, nil + }}, + Region: "us-west-1", + Credentials: credentials.NewStaticCredentials("accessKey", "secretKey", ""), + }, + streamName: "some-stream", + expectedStream: aws.String("arn:aws:kinesis:some-region:123456789012:stream/some-stream"), + expectedErr: nil, + }, + { + name: "returns error when stream not found", + kinesisClient: &KinesisClients{ + Kinesis: &mockedKinesis{DescribeStreamFn: func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { + return nil, errors.New("stream not found") + }}, + Region: "us-west-1", + Credentials: credentials.NewStaticCredentials("accessKey", "secretKey", ""), + }, + streamName: "nonexistent-stream", + expectedStream: nil, + expectedErr: errors.New("unable to get stream arn due to empty client"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kinesisClient.Stream(context.Background(), tt.streamName) + if tt.expectedErr != nil { + require.Error(t, err) + assert.Equal(t, tt.expectedErr.Error(), err.Error()) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedStream, got) + } + }) + } +} + +func TestKinesisClients_WorkerCfg(t *testing.T) { + testCreds := credentials.NewStaticCredentials("accessKey", "secretKey", "") + tests := []struct { + name string + kinesisClient *KinesisClients + streamName string + consumer string + mode string + expectedConfig *config.KinesisClientLibConfiguration + }{ + { + name: "successfully creates shared mode worker config", + kinesisClient: &KinesisClients{ + Kinesis: &mockedKinesis{ + DescribeStreamFn: func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { + return &kinesis.DescribeStreamOutput{ + StreamDescription: &kinesis.StreamDescription{ + StreamARN: aws.String("arn:aws:kinesis:us-east-1:123456789012:stream/existing-stream"), + }, + }, nil + }, + }, + Region: "us-west-1", + Credentials: testCreds, + }, + streamName: "existing-stream", + consumer: "consumer1", + mode: "shared", + expectedConfig: config.NewKinesisClientLibConfigWithCredential( + "consumer1", "existing-stream", "us-west-1", "consumer1", testCreds, + ), + }, + { + name: "returns nil when mode is not shared", + kinesisClient: &KinesisClients{ + Kinesis: &mockedKinesis{ + DescribeStreamFn: func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { + return &kinesis.DescribeStreamOutput{ + StreamDescription: &kinesis.StreamDescription{ + StreamARN: aws.String("arn:aws:kinesis:us-east-1:123456789012:stream/existing-stream"), + }, + }, nil + }, + }, + Region: "us-west-1", + Credentials: testCreds, + }, + streamName: "existing-stream", + consumer: "consumer1", + mode: "exclusive", + expectedConfig: nil, + }, + { + name: "returns nil when client is nil", + kinesisClient: &KinesisClients{ + Kinesis: nil, + Region: "us-west-1", + Credentials: credentials.NewStaticCredentials("accessKey", "secretKey", ""), + }, + streamName: "existing-stream", + consumer: "consumer1", + mode: "shared", + expectedConfig: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := tt.kinesisClient.WorkerCfg(context.Background(), tt.streamName, tt.consumer, tt.mode) + if tt.expectedConfig == nil { + assert.Equal(t, tt.expectedConfig, cfg) + return + } + assert.Equal(t, tt.expectedConfig.StreamName, cfg.StreamName) + assert.Equal(t, tt.expectedConfig.EnhancedFanOutConsumerName, cfg.EnhancedFanOutConsumerName) + assert.Equal(t, tt.expectedConfig.EnableEnhancedFanOutConsumer, cfg.EnableEnhancedFanOutConsumer) + assert.Equal(t, tt.expectedConfig.RegionName, cfg.RegionName) + }) + } +} diff --git a/common/authentication/aws/static.go b/common/authentication/aws/static.go new file mode 100644 index 0000000000..e79dee1841 --- /dev/null +++ b/common/authentication/aws/static.go @@ -0,0 +1,431 @@ +/* +Copyright 2021 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + "errors" + "fmt" + "strconv" + "sync" + "time" + + awsv2 "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + v2creds "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/feature/rds/auth" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/dapr/kit/logger" +) + +type StaticAuth struct { + mu sync.RWMutex + logger logger.Logger + + region *string + endpoint *string + accessKey *string + secretKey *string + sessionToken string + + assumeRoleARN *string + sessionName string + + session *session.Session + cfg *aws.Config + clients *Clients +} + +func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth, error) { + auth := &StaticAuth{ + logger: opts.Logger, + cfg: func() *aws.Config { + // if nil is passed or it's just a default cfg, + // then we use the options to build the aws cfg. + if cfg != nil && cfg != aws.NewConfig() { + return cfg + } + return GetConfig(opts) + }(), + clients: newClients(), + } + + if opts.Region != "" { + auth.region = &opts.Region + } + if opts.Endpoint != "" { + auth.endpoint = &opts.Endpoint + } + if opts.AccessKey != "" { + auth.accessKey = &opts.AccessKey + } + if opts.SecretKey != "" { + auth.secretKey = &opts.SecretKey + } + if opts.SessionToken != "" { + auth.sessionToken = opts.SessionToken + } + if opts.AssumeRoleARN != "" { + auth.assumeRoleARN = &opts.AssumeRoleARN + } + if opts.SessionName != "" { + auth.sessionName = opts.SessionName + } + + initialSession, err := auth.createSession() + if err != nil { + return nil, fmt.Errorf("failed to get token client: %v", err) + } + + auth.session = initialSession + + return auth, nil +} + +// This is to be used only for test purposes to inject mocked clients +func (a *StaticAuth) WithMockClients(clients *Clients) { + a.clients = clients +} + +func (a *StaticAuth) S3() *S3Clients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.s3 != nil { + return a.clients.s3 + } + + s3Clients := S3Clients{} + a.clients.s3 = &s3Clients + a.clients.s3.New(a.session) + return a.clients.s3 +} + +func (a *StaticAuth) DynamoDB() *DynamoDBClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.Dynamo != nil { + return a.clients.Dynamo + } + + clients := DynamoDBClients{} + a.clients.Dynamo = &clients + a.clients.Dynamo.New(a.session) + + return a.clients.Dynamo +} + +func (a *StaticAuth) Sqs() *SqsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.sqs != nil { + return a.clients.sqs + } + + clients := SqsClients{} + a.clients.sqs = &clients + a.clients.sqs.New(a.session) + + return a.clients.sqs +} + +func (a *StaticAuth) Sns() *SnsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.sns != nil { + return a.clients.sns + } + + clients := SnsClients{} + a.clients.sns = &clients + a.clients.sns.New(a.session) + return a.clients.sns +} + +func (a *StaticAuth) SnsSqs() *SnsSqsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.snssqs != nil { + return a.clients.snssqs + } + + clients := SnsSqsClients{} + a.clients.snssqs = &clients + a.clients.snssqs.New(a.session) + return a.clients.snssqs +} + +func (a *StaticAuth) SecretManager() *SecretManagerClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.Secret != nil { + return a.clients.Secret + } + + clients := SecretManagerClients{} + a.clients.Secret = &clients + a.clients.Secret.New(a.session) + return a.clients.Secret +} + +func (a *StaticAuth) ParameterStore() *ParameterStoreClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.ParameterStore != nil { + return a.clients.ParameterStore + } + + clients := ParameterStoreClients{} + a.clients.ParameterStore = &clients + a.clients.ParameterStore.New(a.session) + return a.clients.ParameterStore +} + +func (a *StaticAuth) Kinesis() *KinesisClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.kinesis != nil { + return a.clients.kinesis + } + + clients := KinesisClients{} + a.clients.kinesis = &clients + a.clients.kinesis.New(a.session) + return a.clients.kinesis +} + +func (a *StaticAuth) Ses() *SesClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.ses != nil { + return a.clients.ses + } + + clients := SesClients{} + a.clients.ses = &clients + a.clients.ses.New(a.session) + return a.clients.ses +} + +func (a *StaticAuth) UpdatePostgres(ctx context.Context, poolConfig *pgxpool.Config) { + a.mu.Lock() + defer a.mu.Unlock() + + // Set max connection lifetime to 8 minutes in postgres connection pool configuration. + // Note: this will refresh connections before the 15 min expiration on the IAM AWS auth token, + // while leveraging the BeforeConnect hook to recreate the token in time dynamically. + poolConfig.MaxConnLifetime = time.Minute * 8 + + // Setup connection pool config needed for AWS IAM authentication + poolConfig.BeforeConnect = func(ctx context.Context, pgConfig *pgx.ConnConfig) error { + // Manually reset auth token with aws and reset the config password using the new iam token + pwd, err := a.getDatabaseToken(ctx, poolConfig) + if err != nil { + return fmt.Errorf("failed to get database token: %w", err) + } + pgConfig.Password = pwd + poolConfig.ConnConfig.Password = pwd + + return nil + } +} + +// https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/UsingWithRDS.IAMDBAuth.Connecting.Go.html +func (a *StaticAuth) getDatabaseToken(ctx context.Context, poolConfig *pgxpool.Config) (string, error) { + dbEndpoint := poolConfig.ConnConfig.Host + ":" + strconv.Itoa(int(poolConfig.ConnConfig.Port)) + + // First, check if there are credentials set explicitly with accesskey and secretkey + if a.accessKey != nil && a.secretKey != nil { + awsCfg := v2creds.NewStaticCredentialsProvider(*a.accessKey, *a.secretKey, a.sessionToken) + authenticationToken, err := auth.BuildAuthToken( + ctx, dbEndpoint, *a.region, poolConfig.ConnConfig.User, awsCfg) + if err != nil { + return "", fmt.Errorf("failed to create AWS authentication token: %w", err) + } + + return authenticationToken, nil + } + + // Second, check if we are assuming a role instead + if a.assumeRoleARN != nil { + awsCfg, err := config.LoadDefaultConfig(ctx) + if err != nil { + return "", fmt.Errorf("failed to load default AWS authentication configuration %w", err) + } + stsClient := sts.NewFromConfig(awsCfg) + + assumeRoleCfg, err := config.LoadDefaultConfig(ctx, + config.WithRegion(*a.region), + config.WithCredentialsProvider( + awsv2.NewCredentialsCache( + stscreds.NewAssumeRoleProvider(stsClient, *a.assumeRoleARN, func(aro *stscreds.AssumeRoleOptions) { + if a.sessionName != "" { + aro.RoleSessionName = a.sessionName + } + }), + ), + ), + ) + if err != nil { + return "", fmt.Errorf("failed to assume aws role %w", err) + } + + authenticationToken, err := auth.BuildAuthToken( + ctx, dbEndpoint, *a.region, poolConfig.ConnConfig.User, assumeRoleCfg.Credentials) + if err != nil { + return "", fmt.Errorf("failed to create AWS authentication token: %w", err) + } + return authenticationToken, nil + } + + // Lastly, and by default, just use the default aws configuration + awsCfg, err := config.LoadDefaultConfig(ctx) + if err != nil { + return "", fmt.Errorf("failed to load default AWS authentication configuration %w", err) + } + + authenticationToken, err := auth.BuildAuthToken(ctx, dbEndpoint, *a.region, poolConfig.ConnConfig.User, awsCfg.Credentials) + if err != nil { + return "", fmt.Errorf("failed to create AWS authentication token: %w", err) + } + + return authenticationToken, nil +} + +func (a *StaticAuth) Kafka(opts KafkaOptions) (*KafkaClients, error) { + a.mu.Lock() + defer a.mu.Unlock() + + // This means we've already set the config in our New function + // to use the SASL token provider. + if a.clients.kafka != nil { + return a.clients.kafka, nil + } + + a.clients.kafka = initKafkaClients(opts) + // static auth has additional fields we need added, + // so we add those static auth specific fields here, + // and the rest of the token provider fields are added in New() + tokenProvider := mskTokenProvider{} + if a.assumeRoleARN != nil { + tokenProvider.awsIamRoleArn = *a.assumeRoleARN + } + if a.sessionName != "" { + tokenProvider.awsStsSessionName = a.sessionName + } + + err := a.clients.kafka.New(a.session, &tokenProvider) + if err != nil { + return nil, fmt.Errorf("failed to create AWS IAM Kafka config: %w", err) + } + + return a.clients.kafka, nil +} + +func (a *StaticAuth) createSession() (*session.Session, error) { + var awsConfig *aws.Config + if a.cfg == nil { + awsConfig = aws.NewConfig() + } else { + awsConfig = a.cfg + } + + if a.region != nil { + awsConfig = awsConfig.WithRegion(*a.region) + } + + if a.accessKey != nil && a.secretKey != nil { + // session token is an option field + awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(*a.accessKey, *a.secretKey, a.sessionToken)) + } + + if a.endpoint != nil { + awsConfig = awsConfig.WithEndpoint(*a.endpoint) + } + + // TODO support assume role for all aws components + + awsSession, err := session.NewSessionWithOptions(session.Options{ + Config: *awsConfig, + SharedConfigState: session.SharedConfigEnable, + }) + if err != nil { + return nil, err + } + + userAgentHandler := request.NamedHandler{ + Name: "UserAgentHandler", + Fn: request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion), + } + awsSession.Handlers.Build.PushBackNamed(userAgentHandler) + + return awsSession, nil +} + +func (a *StaticAuth) Close() error { + a.mu.Lock() + defer a.mu.Unlock() + + errs := make([]error, 2) + if a.clients.kafka != nil { + if a.clients.kafka.Producer != nil { + errs[0] = a.clients.kafka.Producer.Close() + a.clients.kafka.Producer = nil + } + if a.clients.kafka.ConsumerGroup != nil { + errs[1] = a.clients.kafka.ConsumerGroup.Close() + a.clients.kafka.ConsumerGroup = nil + } + } + return errors.Join(errs...) +} + +func GetConfigV2(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (awsv2.Config, error) { + optFns := []func(*config.LoadOptions) error{} + if region != "" { + optFns = append(optFns, config.WithRegion(region)) + } + + if accessKey != "" && secretKey != "" { + provider := v2creds.NewStaticCredentialsProvider(accessKey, secretKey, sessionToken) + optFns = append(optFns, config.WithCredentialsProvider(provider)) + } + + awsCfg, err := config.LoadDefaultConfig(context.Background(), optFns...) + if err != nil { + return awsv2.Config{}, err + } + + if endpoint != "" { + awsCfg.BaseEndpoint = &endpoint + } + + return awsCfg, nil +} diff --git a/common/authentication/aws/static_test.go b/common/authentication/aws/static_test.go new file mode 100644 index 0000000000..8ceb5639e4 --- /dev/null +++ b/common/authentication/aws/static_test.go @@ -0,0 +1,72 @@ +package aws + +import ( + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetConfigV2(t *testing.T) { + tests := []struct { + name string + accessKey string + secretKey string + sessionToken string + region string + endpoint string + }{ + { + name: "valid config", + accessKey: "testAccessKey", + secretKey: "testSecretKey", + sessionToken: "testSessionToken", + region: "us-west-2", + endpoint: "https://test.endpoint.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + awsCfg, err := GetConfigV2(tt.accessKey, tt.secretKey, tt.sessionToken, tt.region, tt.endpoint) + require.NoError(t, err) + assert.NotNil(t, awsCfg) + assert.Equal(t, tt.region, awsCfg.Region) + assert.Equal(t, tt.endpoint, *awsCfg.BaseEndpoint) + }) + } +} + +func TestGetTokenClient(t *testing.T) { + tests := []struct { + name string + awsInstance *StaticAuth + }{ + { + name: "valid token client", + awsInstance: &StaticAuth{ + accessKey: aws.String("testAccessKey"), + secretKey: aws.String("testSecretKey"), + sessionToken: "testSessionToken", + region: aws.String("us-west-2"), + endpoint: aws.String("https://test.endpoint.com"), + }, + }, + { + name: "creds from environment", + awsInstance: &StaticAuth{ + region: aws.String("us-west-2"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + session, err := tt.awsInstance.createSession() + require.NotNil(t, session) + require.NoError(t, err) + assert.Equal(t, tt.awsInstance.region, session.Config.Region) + }) + } +} diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go new file mode 100644 index 0000000000..6556ece74c --- /dev/null +++ b/common/authentication/aws/x509.go @@ -0,0 +1,607 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + "crypto/ecdsa" + "crypto/tls" + cryptoX509 "crypto/x509" + "errors" + "fmt" + "net/http" + "runtime" + "strconv" + "sync" + "time" + + awsv2 "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + v2creds "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/feature/rds/auth" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/arn" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + awssh "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere/rolesanywhereiface" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/sts" + + cryptopem "github.com/dapr/kit/crypto/pem" + spiffecontext "github.com/dapr/kit/crypto/spiffe/context" + "github.com/dapr/kit/logger" + kitmd "github.com/dapr/kit/metadata" + "github.com/dapr/kit/ptr" +) + +func isX509Auth(m map[string]string) bool { + tp := m["trustProfileArn"] + ta := m["trustAnchorArn"] + ar := m["assumeRoleArn"] + return tp != "" && ta != "" && ar != "" +} + +type x509Options struct { + TrustProfileArn *string `json:"trustProfileArn" mapstructure:"trustProfileArn"` + TrustAnchorArn *string `json:"trustAnchorArn" mapstructure:"trustAnchorArn"` + AssumeRoleArn *string `json:"assumeRoleArn" mapstructure:"assumeRoleArn"` +} + +type x509 struct { + mu sync.RWMutex + wg sync.WaitGroup + closeCh chan struct{} + + logger logger.Logger + clients *Clients + rolesAnywhereClient rolesanywhereiface.RolesAnywhereAPI // this is so we can mock it in tests + session *session.Session + cfg *aws.Config + + chainPEM []byte + keyPEM []byte + + region *string + trustProfileArn *string + trustAnchorArn *string + assumeRoleArn *string + sessionName string +} + +func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) { + var x509Auth x509Options + if err := kitmd.DecodeMetadata(opts.Properties, &x509Auth); err != nil { + return nil, err + } + + switch { + case x509Auth.TrustProfileArn == nil: + return nil, errors.New("trustProfileArn is required") + case x509Auth.TrustAnchorArn == nil: + return nil, errors.New("trustAnchorArn is required") + case x509Auth.AssumeRoleArn == nil: + return nil, errors.New("assumeRoleArn is required") + } + + auth := &x509{ + logger: opts.Logger, + trustProfileArn: x509Auth.TrustProfileArn, + trustAnchorArn: x509Auth.TrustAnchorArn, + assumeRoleArn: x509Auth.AssumeRoleArn, + cfg: func() *aws.Config { + // if nil is passed or it's just a default cfg, + // then we use the options to build the aws cfg. + if cfg != nil && cfg != aws.NewConfig() { + return cfg + } + return GetConfig(opts) + }(), + clients: newClients(), + closeCh: make(chan struct{}), + } + + if err := auth.getCertPEM(ctx); err != nil { + return nil, fmt.Errorf("failed to get x.509 credentials: %v", err) + } + + // Parse trust anchor and profile ARNs + if err := auth.initializeTrustAnchors(); err != nil { + return nil, err + } + + initialSession, err := auth.createOrRefreshSession(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create the initial session: %v", err) + } + auth.session = initialSession + auth.startSessionRefresher() + + return auth, nil +} + +func (a *x509) Close() error { + a.mu.Lock() + defer a.mu.Unlock() + close(a.closeCh) + a.wg.Wait() + + errs := make([]error, 2) + if a.clients.kafka != nil { + if a.clients.kafka.Producer != nil { + errs[0] = a.clients.kafka.Producer.Close() + a.clients.kafka.Producer = nil + } + if a.clients.kafka.ConsumerGroup != nil { + errs[1] = a.clients.kafka.ConsumerGroup.Close() + a.clients.kafka.ConsumerGroup = nil + } + } + return errors.Join(errs...) +} + +func (a *x509) getCertPEM(ctx context.Context) error { + // retrieve svid from spiffe context + svid, ok := spiffecontext.From(ctx) + if !ok { + return errors.New("no SVID found in context") + } + // get x.509 svid + svidx, err := svid.GetX509SVID() + if err != nil { + return err + } + + // marshal x.509 svid to pem format + chainPEM, keyPEM, err := svidx.Marshal() + if err != nil { + return fmt.Errorf("failed to marshal SVID: %w", err) + } + + a.chainPEM = chainPEM + a.keyPEM = keyPEM + return nil +} + +func (a *x509) S3() *S3Clients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.s3 != nil { + return a.clients.s3 + } + + s3Clients := S3Clients{} + a.clients.s3 = &s3Clients + a.clients.s3.New(a.session) + return a.clients.s3 +} + +func (a *x509) DynamoDB() *DynamoDBClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.Dynamo != nil { + return a.clients.Dynamo + } + + clients := DynamoDBClients{} + a.clients.Dynamo = &clients + a.clients.Dynamo.New(a.session) + + return a.clients.Dynamo +} + +func (a *x509) Sqs() *SqsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.sqs != nil { + return a.clients.sqs + } + + clients := SqsClients{} + a.clients.sqs = &clients + a.clients.sqs.New(a.session) + + return a.clients.sqs +} + +func (a *x509) Sns() *SnsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.sns != nil { + return a.clients.sns + } + + clients := SnsClients{} + a.clients.sns = &clients + a.clients.sns.New(a.session) + return a.clients.sns +} + +func (a *x509) SnsSqs() *SnsSqsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.snssqs != nil { + return a.clients.snssqs + } + + clients := SnsSqsClients{} + a.clients.snssqs = &clients + a.clients.snssqs.New(a.session) + return a.clients.snssqs +} + +func (a *x509) SecretManager() *SecretManagerClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.Secret != nil { + return a.clients.Secret + } + + clients := SecretManagerClients{} + a.clients.Secret = &clients + a.clients.Secret.New(a.session) + return a.clients.Secret +} + +func (a *x509) ParameterStore() *ParameterStoreClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.ParameterStore != nil { + return a.clients.ParameterStore + } + + clients := ParameterStoreClients{} + a.clients.ParameterStore = &clients + a.clients.ParameterStore.New(a.session) + return a.clients.ParameterStore +} + +func (a *x509) Kinesis() *KinesisClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.kinesis != nil { + return a.clients.kinesis + } + + clients := KinesisClients{} + a.clients.kinesis = &clients + a.clients.kinesis.New(a.session) + return a.clients.kinesis +} + +func (a *x509) Ses() *SesClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.ses != nil { + return a.clients.ses + } + + clients := SesClients{} + a.clients.ses = &clients + a.clients.ses.New(a.session) + return a.clients.ses +} + +// https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/UsingWithRDS.IAMDBAuth.Connecting.Go.html +func (a *x509) getDatabaseToken(ctx context.Context, poolConfig *pgxpool.Config) (string, error) { + dbEndpoint := poolConfig.ConnConfig.Host + ":" + strconv.Itoa(int(poolConfig.ConnConfig.Port)) + + // First, check session credentials. + // This should always be what we use to generate the x509 auth credentials for postgres. + // However, we can leave the Second and Lastly checks as backup for now. + var creds credentials.Value + if a.session != nil { + var err error + creds, err = a.session.Config.Credentials.Get() + if err != nil { + a.logger.Infof("failed to get access key and secret key, will fallback to reading the default AWS credentials file: %w", err) + } + } + + if creds.AccessKeyID != "" && creds.SecretAccessKey != "" { + creds, err := a.session.Config.Credentials.Get() + if err != nil { + return "", fmt.Errorf("failed to retrieve session credentials: %w", err) + } + awsCfg := v2creds.NewStaticCredentialsProvider(creds.AccessKeyID, creds.SecretAccessKey, creds.SessionToken) + authenticationToken, err := auth.BuildAuthToken( + ctx, dbEndpoint, *a.region, poolConfig.ConnConfig.User, awsCfg) + if err != nil { + return "", fmt.Errorf("failed to create AWS authentication token: %w", err) + } + + return authenticationToken, nil + } + + // Second, check if we are assuming a role instead + if a.assumeRoleArn != nil { + awsCfg, err := config.LoadDefaultConfig(ctx) + if err != nil { + return "", fmt.Errorf("failed to load default AWS authentication configuration %w", err) + } + stsClient := sts.NewFromConfig(awsCfg) + + assumeRoleCfg, err := config.LoadDefaultConfig(ctx, + config.WithRegion(*a.region), + config.WithCredentialsProvider( + awsv2.NewCredentialsCache( + stscreds.NewAssumeRoleProvider(stsClient, *a.assumeRoleArn, func(aro *stscreds.AssumeRoleOptions) { + if a.sessionName != "" { + aro.RoleSessionName = a.sessionName + } + }), + ), + ), + ) + if err != nil { + return "", fmt.Errorf("failed to assume aws role %w", err) + } + + authenticationToken, err := auth.BuildAuthToken( + ctx, dbEndpoint, *a.region, poolConfig.ConnConfig.User, assumeRoleCfg.Credentials) + if err != nil { + return "", fmt.Errorf("failed to create AWS authentication token: %w", err) + } + return authenticationToken, nil + } + + // Lastly, and by default, just use the default aws configuration + awsCfg, err := config.LoadDefaultConfig(ctx) + if err != nil { + return "", fmt.Errorf("failed to load default AWS authentication configuration %w", err) + } + + authenticationToken, err := auth.BuildAuthToken( + ctx, dbEndpoint, *a.region, poolConfig.ConnConfig.User, awsCfg.Credentials) + if err != nil { + return "", fmt.Errorf("failed to create AWS authentication token: %w", err) + } + + return authenticationToken, nil +} + +func (a *x509) UpdatePostgres(ctx context.Context, poolConfig *pgxpool.Config) { + a.mu.Lock() + defer a.mu.Unlock() + + // Set max connection lifetime to 8 minutes in postgres connection pool configuration. + // Note: this will refresh connections before the 15 min expiration on the IAM AWS auth token, + // while leveraging the BeforeConnect hook to recreate the token in time dynamically. + poolConfig.MaxConnLifetime = time.Minute * 8 + + // Setup connection pool config needed for AWS IAM authentication + poolConfig.BeforeConnect = func(ctx context.Context, pgConfig *pgx.ConnConfig) error { + // Manually reset auth token with aws and reset the config password using the new iam token + pwd, err := a.getDatabaseToken(ctx, poolConfig) + if err != nil { + return fmt.Errorf("failed to get database token: %w", err) + } + pgConfig.Password = pwd + poolConfig.ConnConfig.Password = pwd + + return nil + } +} + +func (a *x509) Kafka(opts KafkaOptions) (*KafkaClients, error) { + a.mu.Lock() + defer a.mu.Unlock() + + // This means we've already set the config in our New function + // to use the SASL token provider. + if a.clients.kafka != nil { + return a.clients.kafka, nil + } + + a.clients.kafka = initKafkaClients(opts) + // Note: we pass in nil for token provider, + // as there are no special fields for x509 auth for it. + err := a.clients.kafka.New(a.session, nil) + if err != nil { + return nil, fmt.Errorf("failed to create AWS IAM Kafka config: %w", err) + } + return a.clients.kafka, nil +} + +func (a *x509) initializeTrustAnchors() error { + var ( + trustAnchor arn.ARN + profile arn.ARN + err error + ) + if a.trustAnchorArn != nil { + trustAnchor, err = arn.Parse(*a.trustAnchorArn) + if err != nil { + return err + } + a.region = &trustAnchor.Region + } + + if a.trustProfileArn != nil { + profile, err = arn.Parse(*a.trustProfileArn) + if err != nil { + return err + } + + if profile.Region != "" && trustAnchor.Region != profile.Region { + return fmt.Errorf("trust anchor and profile must be in the same region: trustAnchor=%s, profile=%s", + trustAnchor.Region, profile.Region) + } + } + return nil +} + +func (a *x509) setSigningFunction(rolesAnywhereClient *rolesanywhere.RolesAnywhere) error { + certs, err := cryptopem.DecodePEMCertificatesChain(a.chainPEM) + if err != nil { + return err + } + + ints := make([]cryptoX509.Certificate, 0, len(certs)-1) + for i := range certs[1:] { + ints = append(ints, *certs[i+1]) + } + + key, err := cryptopem.DecodePEMPrivateKey(a.keyPEM) + if err != nil { + return err + } + + keyECDSA := key.(*ecdsa.PrivateKey) + signFunc := awssh.CreateSignFunction(*keyECDSA, *certs[0], ints) + + agentHandlerFunc := request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH) + rolesAnywhereClient.Handlers.Build.RemoveByName("core.SDKVersionUserAgentHandler") + rolesAnywhereClient.Handlers.Build.PushBackNamed(request.NamedHandler{Name: "v4x509.CredHelperUserAgentHandler", Fn: agentHandlerFunc}) + rolesAnywhereClient.Handlers.Sign.Clear() + rolesAnywhereClient.Handlers.Sign.PushBackNamed(request.NamedHandler{Name: "v4x509.SignRequestHandler", Fn: signFunc}) + + return nil +} + +func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, error) { + a.mu.Lock() + defer a.mu.Unlock() + + client := &http.Client{Transport: &http.Transport{ + TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, + }} + var mySession *session.Session + + var awsConfig *aws.Config + if a.cfg == nil { + awsConfig = aws.NewConfig().WithHTTPClient(client).WithLogLevel(aws.LogOff) + } else { + awsConfig = a.cfg.WithHTTPClient(client).WithLogLevel(aws.LogOff) + } + if a.region != nil { + awsConfig.WithRegion(*a.region) + } + // this is needed for testing purposes to mock the client, + // so code never sets the client, but tests do. + var rolesClient *rolesanywhere.RolesAnywhere + if a.rolesAnywhereClient == nil { + mySession = session.Must(session.NewSession(awsConfig)) + rolesAnywhereClient := rolesanywhere.New(mySession, awsConfig) + // Set up signing function and handlers + if err := a.setSigningFunction(rolesAnywhereClient); err != nil { + return nil, err + } + rolesClient = rolesAnywhereClient + } + + createSessionRequest := rolesanywhere.CreateSessionInput{ + Cert: ptr.Of(string(a.chainPEM)), + ProfileArn: a.trustProfileArn, + TrustAnchorArn: a.trustAnchorArn, + RoleArn: a.assumeRoleArn, + // https://aws.amazon.com/about-aws/whats-new/2024/03/iam-roles-anywhere-credentials-valid-12-hours/#:~:text=The%20duration%20can%20range%20from,and%20applications%2C%20to%20use%20X. + DurationSeconds: aws.Int64(int64(time.Hour.Seconds())), // AWS default is 1hr timeout + InstanceProperties: nil, + SessionName: nil, + } + + var output *rolesanywhere.CreateSessionOutput + if a.rolesAnywhereClient != nil { + var err error + output, err = a.rolesAnywhereClient.CreateSessionWithContext(ctx, &createSessionRequest) + if err != nil { + return nil, fmt.Errorf("failed to create session using dapr app identity: %w", err) + } + } else { + var err error + output, err = rolesClient.CreateSessionWithContext(ctx, &createSessionRequest) + if err != nil { + return nil, fmt.Errorf("failed to create session using dapr app identity: %w", err) + } + } + + if output == nil || len(output.CredentialSet) != 1 { + return nil, fmt.Errorf("expected 1 credential set from X.509 rolesanyway response, got %d", len(output.CredentialSet)) + } + + accessKey := output.CredentialSet[0].Credentials.AccessKeyId + secretKey := output.CredentialSet[0].Credentials.SecretAccessKey + sessionToken := output.CredentialSet[0].Credentials.SessionToken + awsCreds := credentials.NewStaticCredentials(*accessKey, *secretKey, *sessionToken) + sess := session.Must(session.NewSession(&aws.Config{ + Credentials: awsCreds, + }, awsConfig)) + if sess == nil { + return nil, errors.New("session is nil") + } + + return sess, nil +} + +func (a *x509) startSessionRefresher() { + a.logger.Infof("starting session refresher for x509 auth") + + a.wg.Add(1) + go func() { + defer a.wg.Done() + for { + // renew at ~half the lifespan + expiration, err := a.session.Config.Credentials.ExpiresAt() + if err != nil { + a.logger.Errorf("Failed to retrieve session expiration time, using 30 minute interval: %w", err) + expiration = time.Now().Add(time.Hour) + } + timeUntilExpiration := time.Until(expiration) + refreshInterval := timeUntilExpiration / 2 + select { + case <-time.After(refreshInterval): + a.refreshClient() + case <-a.closeCh: + a.logger.Debugf("Session refresher is stopped") + return + } + } + }() +} + +func (a *x509) refreshClient() { + for { + newSession, err := a.createOrRefreshSession(context.Background()) + if err == nil { + err = a.clients.refresh(newSession) + if err != nil { + a.logger.Errorf("Failed to refresh client, retrying in 5 seconds: %w", err) + } + a.logger.Debugf("AWS IAM Roles Anywhere session credentials refreshed successfully") + return + } + a.logger.Errorf("Failed to refresh session, retrying in 5 seconds: %w", err) + select { + case <-time.After(time.Second * 5): + case <-a.closeCh: + return + } + } +} diff --git a/common/authentication/aws/x509_test.go b/common/authentication/aws/x509_test.go new file mode 100644 index 0000000000..3f7d2189c3 --- /dev/null +++ b/common/authentication/aws/x509_test.go @@ -0,0 +1,125 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + cryptoX509 "crypto/x509" + "sync/atomic" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere/rolesanywhereiface" + "github.com/spiffe/go-spiffe/v2/spiffeid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/dapr/kit/crypto/spiffe" + spiffecontext "github.com/dapr/kit/crypto/spiffe/context" + "github.com/dapr/kit/crypto/test" + "github.com/dapr/kit/logger" +) + +type mockRolesAnywhereClient struct { + rolesanywhereiface.RolesAnywhereAPI + + CreateSessionOutput *rolesanywhere.CreateSessionOutput + CreateSessionError error +} + +func (m *mockRolesAnywhereClient) CreateSessionWithContext(ctx context.Context, input *rolesanywhere.CreateSessionInput, opts ...request.Option) (*rolesanywhere.CreateSessionOutput, error) { + return m.CreateSessionOutput, m.CreateSessionError +} + +func TestGetX509Client(t *testing.T) { + tests := []struct { + name string + mockOutput *rolesanywhere.CreateSessionOutput + mockError error + }{ + { + name: "valid x509 client", + mockOutput: &rolesanywhere.CreateSessionOutput{ + CredentialSet: []*rolesanywhere.CredentialResponse{ + { + Credentials: &rolesanywhere.Credentials{ + AccessKeyId: aws.String("mockAccessKeyId"), + SecretAccessKey: aws.String("mockSecretAccessKey"), + SessionToken: aws.String("mockSessionToken"), + Expiration: aws.String(time.Now().Add(15 * time.Minute).Format(time.RFC3339)), + }, + }, + }, + }, + mockError: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockSvc := &mockRolesAnywhereClient{ + CreateSessionOutput: tt.mockOutput, + CreateSessionError: tt.mockError, + } + mockAWS := x509{ + logger: logger.NewLogger("testLogger"), + assumeRoleArn: aws.String("arn:aws:iam:012345678910:role/exampleIAMRoleName"), + trustAnchorArn: aws.String("arn:aws:rolesanywhere:us-west-1:012345678910:trust-anchor/01234568-0123-0123-0123-012345678901"), + trustProfileArn: aws.String("arn:aws:rolesanywhere:us-west-1:012345678910:profile/01234568-0123-0123-0123-012345678901"), + rolesAnywhereClient: mockSvc, + } + pki := test.GenPKI(t, test.PKIOptions{ + LeafID: spiffeid.RequireFromString("spiffe://example.com/foo/bar"), + }) + + respCert := []*cryptoX509.Certificate{pki.LeafCert} + var respErr error + + var fetches atomic.Int32 + s := spiffe.New(spiffe.Options{ + Log: logger.NewLogger("test"), + RequestSVIDFn: func(context.Context, []byte) ([]*cryptoX509.Certificate, error) { + fetches.Add(1) + return respCert, respErr + }, + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + errCh := make(chan error) + go func() { + errCh <- s.Run(ctx) + }() + + select { + case err := <-errCh: + require.NoError(t, err) + default: + } + + err := s.Ready(ctx) + require.NoError(t, err) + + // inject the SVID source into the context + ctx = spiffecontext.With(ctx, s) + session, err := mockAWS.createOrRefreshSession(ctx) + + require.NoError(t, err) + assert.NotNil(t, session) + }) + } +} diff --git a/common/authentication/postgresql/metadata.go b/common/authentication/postgresql/metadata.go index 7cacecfaa4..7a197dffc3 100644 --- a/common/authentication/postgresql/metadata.go +++ b/common/authentication/postgresql/metadata.go @@ -26,6 +26,7 @@ import ( "github.com/dapr/components-contrib/common/authentication/aws" "github.com/dapr/components-contrib/common/authentication/azure" "github.com/dapr/components-contrib/metadata" + "github.com/dapr/kit/logger" ) // PostgresAuthMetadata contains authentication metadata for PostgreSQL components. @@ -86,16 +87,47 @@ func (m *PostgresAuthMetadata) InitWithMetadata(meta map[string]string, opts Ini return nil } -func (m *PostgresAuthMetadata) ValidateAwsIamFields() (string, string, string, error) { +func (m *PostgresAuthMetadata) BuildAwsIamOptions(logger logger.Logger, properties map[string]string) (*aws.Options, error) { awsRegion, _ := metadata.GetMetadataProperty(m.awsEnv.Metadata, "AWSRegion") - if awsRegion == "" { - return "", "", "", errors.New("metadata property AWSRegion is missing") + region, _ := metadata.GetMetadataProperty(m.awsEnv.Metadata, "region") + if region == "" { + region = awsRegion } + if region == "" { + return nil, errors.New("metadata properties 'region' or 'AWSRegion' is missing") + } + // Note: access key and secret keys can be optional // in the event users are leveraging the credential files for an access token. awsAccessKey, _ := metadata.GetMetadataProperty(m.awsEnv.Metadata, "AWSAccessKey") + // This is needed as we remove the awsAccessKey field to use the builtin AWS profile 'accessKey' field instead. + accessKey, _ := metadata.GetMetadataProperty(m.awsEnv.Metadata, "AccessKey") + if awsAccessKey == "" || accessKey != "" { + awsAccessKey = accessKey + } awsSecretKey, _ := metadata.GetMetadataProperty(m.awsEnv.Metadata, "AWSSecretKey") - return awsRegion, awsAccessKey, awsSecretKey, nil + // This is needed as we remove the awsSecretKey field to use the builtin AWS profile 'secretKey' field instead. + secretKey, _ := metadata.GetMetadataProperty(m.awsEnv.Metadata, "SecretKey") + if awsSecretKey == "" || secretKey != "" { + awsSecretKey = secretKey + } + sessionToken, _ := metadata.GetMetadataProperty(m.awsEnv.Metadata, "sessionToken") + assumeRoleArn, _ := metadata.GetMetadataProperty(m.awsEnv.Metadata, "assumeRoleArn") + sessionName, _ := metadata.GetMetadataProperty(m.awsEnv.Metadata, "sessionName") + if sessionName == "" { + sessionName = "DaprDefaultSession" + } + return &aws.Options{ + Region: region, + AccessKey: awsAccessKey, + SecretKey: awsSecretKey, + SessionToken: sessionToken, + AssumeRoleARN: assumeRoleArn, + SessionName: sessionName, + + Logger: logger, + Properties: properties, + }, nil } // GetPgxPoolConfig returns the pgxpool.Config object that contains the credentials for connecting to PostgreSQL. @@ -154,27 +186,6 @@ func (m *PostgresAuthMetadata) GetPgxPoolConfig() (*pgxpool.Config, error) { cc.Password = at.Token return nil } - case m.UseAWSIAM: - // We should use AWS IAM - awsRegion, awsAccessKey, awsSecretKey, err := m.ValidateAwsIamFields() - if err != nil { - err = fmt.Errorf("failed to validate AWS IAM authentication fields: %w", err) - return nil, err - } - - awsOpts := aws.AWSIAMAuthOptions{ - PoolConfig: config, - ConnectionString: m.ConnectionString, - Region: awsRegion, - AccessKey: awsAccessKey, - SecretKey: awsSecretKey, - } - - err = awsOpts.InitiateAWSIAMAuth() - if err != nil { - err = fmt.Errorf("failed to initiate AWS IAM authentication rotation: %w", err) - return nil, err - } } return config, nil diff --git a/common/component/azure/eventhubs/eventhubs.go b/common/component/azure/eventhubs/eventhubs.go index f5724e891f..5c0bd825db 100644 --- a/common/component/azure/eventhubs/eventhubs.go +++ b/common/component/azure/eventhubs/eventhubs.go @@ -127,6 +127,11 @@ func (aeh *AzureEventHubs) EventHubName() string { return aeh.metadata.hubName } +// GetAllMessageProperties returns a boolean to indicate whether to return all properties for an event hubs message. +func (aeh *AzureEventHubs) GetAllMessageProperties() bool { + return aeh.metadata.GetAllMessageProperties +} + // Publish a batch of messages. func (aeh *AzureEventHubs) Publish(ctx context.Context, topic string, messages []*azeventhubs.EventData, batchOpts *azeventhubs.EventDataBatchOptions) error { // Get the producer client @@ -165,7 +170,7 @@ func (aeh *AzureEventHubs) GetBindingsHandlerFunc(topic string, getAllProperties return nil, fmt.Errorf("expected 1 message, got %d", len(messages)) } - bindingsMsg, err := NewBindingsReadResponseFromEventData(messages[0], topic, getAllProperties) + bindingsMsg, err := NewBindingsReadResponseFromEventData(messages[0], topic, aeh.GetAllMessageProperties()) if err != nil { return nil, fmt.Errorf("failed to get bindings read response from azure eventhubs message: %w", err) } @@ -242,12 +247,6 @@ func (aeh *AzureEventHubs) Subscribe(subscribeCtx context.Context, config Subscr } topic := config.Topic - // Get the processor client - processor, err := aeh.getProcessorForTopic(subscribeCtx, topic) - if err != nil { - return fmt.Errorf("error trying to establish a connection: %w", err) - } - // This component has built-in retries because Event Hubs doesn't support N/ACK for messages retryHandler := func(ctx context.Context, events []*azeventhubs.ReceivedEventData) ([]HandlerResponseItem, error) { b := aeh.backOffConfig.NewBackOffWithContext(ctx) @@ -277,51 +276,58 @@ func (aeh *AzureEventHubs) Subscribe(subscribeCtx context.Context, config Subscr subscriptionLoopFinished := make(chan bool, 1) - // Process all partition clients as they come in - subscriberLoop := func() { - for { - // This will block until a new partition client is available - // It returns nil if processor.Run terminates or if the context is canceled - partitionClient := processor.NextPartitionClient(subscribeCtx) - if partitionClient == nil { - subscriptionLoopFinished <- true - return - } - aeh.logger.Debugf("Received client for partition %s", partitionClient.PartitionID()) - - // Once we get a partition client, process the events in a separate goroutine - go func() { - processErr := aeh.processEvents(subscribeCtx, partitionClient, retryConfig) - // Do not log context.Canceled which happens at shutdown - if processErr != nil && !errors.Is(processErr, context.Canceled) { - aeh.logger.Errorf("Error processing events from partition client: %v", processErr) - } - }() - } - } - - // Start the processor + // Start the subscribe + processor loop go func() { for { - go subscriberLoop() - // This is a blocking call that runs until the context is canceled - err = processor.Run(subscribeCtx) - // Exit if the context is canceled - if err != nil && errors.Is(err, context.Canceled) { - return - } + // Get the processor client + processor, err := aeh.getProcessorForTopic(subscribeCtx, topic) if err != nil { - aeh.logger.Errorf("Error from event processor: %v", err) + aeh.logger.Errorf("error trying to establish a connection: %w", err) } else { - aeh.logger.Debugf("Event processor terminated without error") - } - // wait for subscription loop finished signal - select { - case <-subscribeCtx.Done(): - return - case <-subscriptionLoopFinished: - // noop + // Process all partition clients as they come in + subscriberLoop := func() { + for { + // This will block until a new partition client is available + // It returns nil if processor.Run terminates or if the context is canceled + partitionClient := processor.NextPartitionClient(subscribeCtx) + if partitionClient == nil { + subscriptionLoopFinished <- true + return + } + aeh.logger.Debugf("Received client for partition %s", partitionClient.PartitionID()) + + // Once we get a partition client, process the events in a separate goroutine + go func() { + processErr := aeh.processEvents(subscribeCtx, partitionClient, retryConfig) + // Do not log context.Canceled which happens at shutdown + if processErr != nil && !errors.Is(processErr, context.Canceled) { + aeh.logger.Errorf("Error processing events from partition client: %v", processErr) + } + }() + } + } + + go subscriberLoop() + // This is a blocking call that runs until the context is canceled or a non-recoverable error is returned. + err = processor.Run(subscribeCtx) + // Exit if the context is canceled + if err != nil && errors.Is(err, context.Canceled) { + return + } + if err != nil { + aeh.logger.Errorf("Error from event processor: %v", err) + } else { + aeh.logger.Debugf("Event processor terminated without error") + } + // wait for subscription loop finished signal + select { + case <-subscribeCtx.Done(): + return + case <-subscriptionLoopFinished: + // noop + } } + // Waiting here is not strictly necessary, however, we will wait for a short time to increase the likelihood of transient errors having disappeared select { case <-subscribeCtx.Done(): @@ -393,7 +399,11 @@ func (aeh *AzureEventHubs) processEvents(subscribeCtx context.Context, partition if len(events) != 0 { // Handle received message - go aeh.handleAsync(subscribeCtx, config.Topic, events, config.Handler) + if aeh.metadata.EnableInOrderMessageDelivery { + aeh.handleAsync(subscribeCtx, config.Topic, events, config.Handler) + } else { + go aeh.handleAsync(subscribeCtx, config.Topic, events, config.Handler) + } // Checkpointing disabled for CheckPointFrequencyPerPartition == 0 if config.CheckPointFrequencyPerPartition > 0 { diff --git a/common/component/azure/eventhubs/eventhubs_test.go b/common/component/azure/eventhubs/eventhubs_test.go index bf63c19b6b..8381c07743 100644 --- a/common/component/azure/eventhubs/eventhubs_test.go +++ b/common/component/azure/eventhubs/eventhubs_test.go @@ -72,6 +72,18 @@ func TestParseEventHubsMetadata(t *testing.T) { require.Error(t, err) require.ErrorContains(t, err, "one of connectionString or eventHubNamespace is required") }) + + t.Run("test in order delivery", func(t *testing.T) { + metadata := map[string]string{ + "enableInOrderMessageDelivery": "true", + "connectionString": "fake", + } + + m, err := parseEventHubsMetadata(metadata, false, testLogger) + + require.NoError(t, err) + require.True(t, m.EnableInOrderMessageDelivery) + }) } func TestConstructConnectionStringFromTopic(t *testing.T) { diff --git a/common/component/azure/eventhubs/metadata.go b/common/component/azure/eventhubs/metadata.go index b5e94e114e..1b003b3d34 100644 --- a/common/component/azure/eventhubs/metadata.go +++ b/common/component/azure/eventhubs/metadata.go @@ -26,18 +26,20 @@ import ( ) type AzureEventHubsMetadata struct { - ConnectionString string `json:"connectionString" mapstructure:"connectionString"` - EventHubNamespace string `json:"eventHubNamespace" mapstructure:"eventHubNamespace"` - ConsumerID string `json:"consumerID" mapstructure:"consumerID"` - StorageConnectionString string `json:"storageConnectionString" mapstructure:"storageConnectionString"` - StorageAccountName string `json:"storageAccountName" mapstructure:"storageAccountName"` - StorageAccountKey string `json:"storageAccountKey" mapstructure:"storageAccountKey"` - StorageContainerName string `json:"storageContainerName" mapstructure:"storageContainerName"` - EnableEntityManagement bool `json:"enableEntityManagement,string" mapstructure:"enableEntityManagement"` - MessageRetentionInDays int32 `json:"messageRetentionInDays,string" mapstructure:"messageRetentionInDays"` - PartitionCount int32 `json:"partitionCount,string" mapstructure:"partitionCount"` - SubscriptionID string `json:"subscriptionID" mapstructure:"subscriptionID"` - ResourceGroupName string `json:"resourceGroupName" mapstructure:"resourceGroupName"` + ConnectionString string `json:"connectionString" mapstructure:"connectionString"` + EventHubNamespace string `json:"eventHubNamespace" mapstructure:"eventHubNamespace"` + ConsumerID string `json:"consumerID" mapstructure:"consumerID"` + StorageConnectionString string `json:"storageConnectionString" mapstructure:"storageConnectionString"` + StorageAccountName string `json:"storageAccountName" mapstructure:"storageAccountName"` + StorageAccountKey string `json:"storageAccountKey" mapstructure:"storageAccountKey"` + StorageContainerName string `json:"storageContainerName" mapstructure:"storageContainerName"` + EnableEntityManagement bool `json:"enableEntityManagement,string" mapstructure:"enableEntityManagement"` + MessageRetentionInDays int32 `json:"messageRetentionInDays,string" mapstructure:"messageRetentionInDays"` + PartitionCount int32 `json:"partitionCount,string" mapstructure:"partitionCount"` + SubscriptionID string `json:"subscriptionID" mapstructure:"subscriptionID"` + ResourceGroupName string `json:"resourceGroupName" mapstructure:"resourceGroupName"` + EnableInOrderMessageDelivery bool `json:"enableInOrderMessageDelivery,string" mapstructure:"enableInOrderMessageDelivery"` + GetAllMessageProperties bool `json:"getAllMessageProperties,string" mapstructure:"getAllMessageProperties"` // Binding only EventHub string `json:"eventHub" mapstructure:"eventHub" mdonly:"bindings"` diff --git a/common/component/kafka/auth.go b/common/component/kafka/auth.go index bd61690c05..ea8cc43fac 100644 --- a/common/component/kafka/auth.go +++ b/common/component/kafka/auth.go @@ -14,16 +14,12 @@ limitations under the License. package kafka import ( - "context" "crypto/tls" "crypto/x509" "errors" "fmt" - "time" "github.com/IBM/sarama" - "github.com/aws/aws-msk-iam-sasl-signer-go/signer" - aws2 "github.com/aws/aws-sdk-go-v2/aws" ) func updatePasswordAuthInfo(config *sarama.Config, metadata *KafkaMetadata, saslUsername, saslPassword string) { @@ -92,58 +88,3 @@ func updateOidcAuthInfo(config *sarama.Config, metadata *KafkaMetadata) error { return nil } - -func updateAWSIAMAuthInfo(ctx context.Context, config *sarama.Config, metadata *KafkaMetadata) error { - config.Net.SASL.Enable = true - config.Net.SASL.Mechanism = sarama.SASLTypeOAuth - config.Net.SASL.TokenProvider = &mskAccessTokenProvider{ - ctx: ctx, - generateTokenTimeout: 10 * time.Second, - region: metadata.AWSRegion, - accessKey: metadata.AWSAccessKey, - secretKey: metadata.AWSSecretKey, - sessionToken: metadata.AWSSessionToken, - awsIamRoleArn: metadata.AWSIamRoleArn, - awsStsSessionName: metadata.AWSStsSessionName, - } - - _, err := config.Net.SASL.TokenProvider.Token() - if err != nil { - return fmt.Errorf("error validating iam credentials %v", err) - } - return nil -} - -type mskAccessTokenProvider struct { - ctx context.Context - generateTokenTimeout time.Duration - accessKey string - secretKey string - sessionToken string - awsIamRoleArn string - awsStsSessionName string - region string -} - -func (m *mskAccessTokenProvider) Token() (*sarama.AccessToken, error) { - // this function can't use the context passed on Init because that context would be cancelled right after Init - ctx, cancel := context.WithTimeout(m.ctx, m.generateTokenTimeout) - defer cancel() - - if m.accessKey != "" && m.secretKey != "" { - token, _, err := signer.GenerateAuthTokenFromCredentialsProvider(ctx, m.region, aws2.CredentialsProviderFunc(func(ctx context.Context) (aws2.Credentials, error) { - return aws2.Credentials{ - AccessKeyID: m.accessKey, - SecretAccessKey: m.secretKey, - SessionToken: m.sessionToken, - }, nil - })) - return &sarama.AccessToken{Token: token}, err - } else if m.awsIamRoleArn != "" { - token, _, err := signer.GenerateAuthTokenFromRole(ctx, m.region, m.awsIamRoleArn, m.awsStsSessionName) - return &sarama.AccessToken{Token: token}, err - } - - token, _, err := signer.GenerateAuthToken(ctx, m.region) - return &sarama.AccessToken{Token: token}, err -} diff --git a/common/component/kafka/clients.go b/common/component/kafka/clients.go new file mode 100644 index 0000000000..8e8111b7b2 --- /dev/null +++ b/common/component/kafka/clients.go @@ -0,0 +1,64 @@ +package kafka + +import ( + "fmt" + + "github.com/IBM/sarama" + + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" +) + +type clients struct { + consumerGroup sarama.ConsumerGroup + producer sarama.SyncProducer +} + +func (k *Kafka) latestClients() (*clients, error) { + switch { + // case 0: use mock clients for testing + case k.mockProducer != nil || k.mockConsumerGroup != nil: + return &clients{ + consumerGroup: k.mockConsumerGroup, + producer: k.mockProducer, + }, nil + + // case 1: use aws clients with refreshable tokens in the cfg + case k.awsAuthProvider != nil: + awsKafkaOpts := awsAuth.KafkaOptions{ + Config: k.config, + ConsumerGroup: k.consumerGroup, + Brokers: k.brokers, + MaxMessageBytes: k.maxMessageBytes, + } + awsKafkaClients, err := k.awsAuthProvider.Kafka(awsKafkaOpts) + if err != nil { + return nil, fmt.Errorf("failed to get AWS IAM Kafka clients: %w", err) + } + return &clients{ + consumerGroup: awsKafkaClients.ConsumerGroup, + producer: awsKafkaClients.Producer, + }, nil + + // case 2: normal static auth profile clients + default: + if k.clients != nil { + return k.clients, nil + } + cg, err := sarama.NewConsumerGroup(k.brokers, k.consumerGroup, k.config) + if err != nil { + return nil, err + } + + p, err := GetSyncProducer(*k.config, k.brokers, k.maxMessageBytes) + if err != nil { + return nil, err + } + + newStaticClients := clients{ + consumerGroup: cg, + producer: p, + } + k.clients = &newStaticClients + return k.clients, nil + } +} diff --git a/common/component/kafka/consumer.go b/common/component/kafka/consumer.go index 7cb923a455..b7ea51240c 100644 --- a/common/component/kafka/consumer.go +++ b/common/component/kafka/consumer.go @@ -14,6 +14,7 @@ limitations under the License. package kafka import ( + "context" "errors" "fmt" "net/url" @@ -32,6 +33,29 @@ type consumer struct { mutex sync.Mutex } +func notifyRecover(consumer *consumer, message *sarama.ConsumerMessage, session sarama.ConsumerGroupSession, b backoff.BackOff) error { + for { + if err := retry.NotifyRecover(func() error { + return consumer.doCallback(session, message) + }, b, func(err error, d time.Duration) { + consumer.k.logger.Warnf("Error processing Kafka message: %s/%d/%d [key=%s]. Error: %v. Retrying...", message.Topic, message.Partition, message.Offset, asBase64String(message.Key), err) + }, func() { + consumer.k.logger.Infof("Successfully processed Kafka message after it previously failed: %s/%d/%d [key=%s]", message.Topic, message.Partition, message.Offset, asBase64String(message.Key)) + }); err != nil { + // If the retry policy got interrupted, it could mean that either + // the policy has reached its maximum number of attempts or the context has been cancelled. + // There is a weird edge case where the error returned is a 'context canceled' error but the session.Context is not done. + // This is a workaround to handle that edge case and reprocess the current message. + if err == context.Canceled && session.Context().Err() == nil { + consumer.k.logger.Warnf("Error processing Kafka message: %s/%d/%d [key=%s]. The error returned is 'context canceled' but the session context is not done. Retrying...") + continue + } + return err + } + return nil + } +} + func (consumer *consumer) ConsumeClaim(session sarama.ConsumerGroupSession, claim sarama.ConsumerGroupClaim) error { b := consumer.k.backOffConfig.NewBackOffWithContext(session.Context()) isBulkSubscribe := consumer.k.checkBulkSubscribe(claim.Topic()) @@ -83,13 +107,7 @@ func (consumer *consumer) ConsumeClaim(session sarama.ConsumerGroupSession, clai } if consumer.k.consumeRetryEnabled { - if err := retry.NotifyRecover(func() error { - return consumer.doCallback(session, message) - }, b, func(err error, d time.Duration) { - consumer.k.logger.Warnf("Error processing Kafka message: %s/%d/%d [key=%s]. Error: %v. Retrying...", message.Topic, message.Partition, message.Offset, asBase64String(message.Key), err) - }, func() { - consumer.k.logger.Infof("Successfully processed Kafka message after it previously failed: %s/%d/%d [key=%s]", message.Topic, message.Partition, message.Offset, asBase64String(message.Key)) - }); err != nil { + if err := notifyRecover(consumer, message, session, b); err != nil { consumer.k.logger.Errorf("Too many failed attempts at processing Kafka message: %s/%d/%d [key=%s]. Error: %v.", message.Topic, message.Partition, message.Offset, asBase64String(message.Key), err) } } else { diff --git a/common/component/kafka/kafka.go b/common/component/kafka/kafka.go index 2f3b67be0d..3e930ec4b2 100644 --- a/common/component/kafka/kafka.go +++ b/common/component/kafka/kafka.go @@ -28,6 +28,7 @@ import ( "github.com/linkedin/goavro/v2" "github.com/riferrei/srclient" + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" "github.com/dapr/components-contrib/pubsub" "github.com/dapr/kit/logger" kitmd "github.com/dapr/kit/metadata" @@ -36,18 +37,23 @@ import ( // Kafka allows reading/writing to a Kafka consumer group. type Kafka struct { - producer sarama.SyncProducer - consumerGroup string - brokers []string - logger logger.Logger - authType string - saslUsername string - saslPassword string - initialOffset int64 - config *sarama.Config - escapeHeaders bool - - cg sarama.ConsumerGroup + // These are used to inject mocked clients for tests + mockConsumerGroup sarama.ConsumerGroup + mockProducer sarama.SyncProducer + clients *clients + + maxMessageBytes int + consumerGroup string + brokers []string + logger logger.Logger + authType string + saslUsername string + saslPassword string + initialOffset int64 + config *sarama.Config + escapeHeaders bool + awsAuthProvider awsAuth.Provider + subscribeTopics TopicHandlerConfig subscribeLock sync.Mutex consumerCancel context.CancelFunc @@ -182,19 +188,32 @@ func (k *Kafka) Init(ctx context.Context, metadata map[string]string) error { // already handled in updateTLSConfig case awsIAMAuthType: k.logger.Info("Configuring AWS IAM authentication") - err = updateAWSIAMAuthInfo(k.internalContext, config, meta) + kafkaIAM, validateErr := k.ValidateAWS(metadata) + if validateErr != nil { + return fmt.Errorf("failed to validate AWS IAM authentication fields: %w", validateErr) + } + opts := awsAuth.Options{ + Logger: k.logger, + Properties: metadata, + Region: kafkaIAM.Region, + Endpoint: "", + AccessKey: kafkaIAM.AccessKey, + SecretKey: kafkaIAM.SecretKey, + SessionToken: kafkaIAM.SessionToken, + AssumeRoleARN: kafkaIAM.IamRoleArn, + SessionName: kafkaIAM.StsSessionName, + } + var provider awsAuth.Provider + provider, err = awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } + k.awsAuthProvider = provider } k.config = config sarama.Logger = SaramaLogBridge{daprLogger: k.logger} - - k.producer, err = getSyncProducer(*k.config, k.brokers, meta.MaxMessageBytes) - if err != nil { - return err - } + k.maxMessageBytes = meta.MaxMessageBytes // Default retry configuration is used if no // backOff properties are set. @@ -208,40 +227,70 @@ func (k *Kafka) Init(ctx context.Context, metadata map[string]string) error { k.consumeRetryInterval = meta.ConsumeRetryInterval if meta.SchemaRegistryURL != "" { + k.logger.Infof("Schema registry URL '%s' provided. Configuring the Schema Registry client.", meta.SchemaRegistryURL) k.srClient = srclient.CreateSchemaRegistryClient(meta.SchemaRegistryURL) // Empty password is a possibility if meta.SchemaRegistryAPIKey != "" { k.srClient.SetCredentials(meta.SchemaRegistryAPIKey, meta.SchemaRegistryAPISecret) } + k.logger.Infof("Schema caching enabled: %v", meta.SchemaCachingEnabled) k.srClient.CachingEnabled(meta.SchemaCachingEnabled) if meta.SchemaCachingEnabled { k.latestSchemaCache = make(map[string]SchemaCacheEntry) + k.logger.Debugf("Schema cache TTL: %v", meta.SchemaLatestVersionCacheTTL) k.latestSchemaCacheTTL = meta.SchemaLatestVersionCacheTTL } } - k.logger.Debug("Kafka message bus initialization complete") - k.cg, err = sarama.NewConsumerGroup(k.brokers, k.consumerGroup, k.config) - if err != nil { - return err + clients, err := k.latestClients() + if err != nil || clients == nil { + return fmt.Errorf("failed to get latest Kafka clients for initialization: %w", err) + } + if clients.producer == nil { + return errors.New("component is closed") } + if clients.consumerGroup == nil { + return errors.New("component is closed") + } + + k.logger.Debug("Kafka message bus initialization complete") return nil } +func (k *Kafka) ValidateAWS(metadata map[string]string) (*awsAuth.DeprecatedKafkaIAM, error) { + const defaultSessionName = "DaprDefaultSession" + // This is needed as we remove the aws prefixed fields to use the builtin AWS profile fields instead. + region := awsAuth.Coalesce(metadata["region"], metadata["awsRegion"]) + accessKey := awsAuth.Coalesce(metadata["accessKey"], metadata["awsAccessKey"]) + secretKey := awsAuth.Coalesce(metadata["secretKey"], metadata["awsSecretKey"]) + role := awsAuth.Coalesce(metadata["assumeRoleArn"], metadata["awsIamRoleArn"]) + session := awsAuth.Coalesce(metadata["sessionName"], metadata["awsStsSessionName"], defaultSessionName) // set default if no value is provided + token := awsAuth.Coalesce(metadata["sessionToken"], metadata["awsSessionToken"]) + + if region == "" { + return nil, errors.New("metadata property AWSRegion is missing") + } + + return &awsAuth.DeprecatedKafkaIAM{ + Region: region, + AccessKey: accessKey, + SecretKey: secretKey, + IamRoleArn: role, + StsSessionName: session, + SessionToken: token, + }, nil +} + func (k *Kafka) Close() error { defer k.wg.Wait() defer k.consumerWG.Wait() - errs := make([]error, 2) + errs := make([]error, 3) if k.closed.CompareAndSwap(false, true) { - close(k.closeCh) - - if k.producer != nil { - errs[0] = k.producer.Close() - k.producer = nil + if k.closeCh != nil { + close(k.closeCh) } - if k.internalContext != nil { k.internalContextCancel() } @@ -252,8 +301,19 @@ func (k *Kafka) Close() error { } k.subscribeLock.Unlock() - if k.cg != nil { - errs[1] = k.cg.Close() + if k.clients != nil { + if k.clients.producer != nil { + errs[0] = k.clients.producer.Close() + k.clients.producer = nil + } + if k.clients.consumerGroup != nil { + errs[1] = k.clients.consumerGroup.Close() + k.clients.consumerGroup = nil + } + } + if k.awsAuthProvider != nil { + errs[2] = k.awsAuthProvider.Close() + k.awsAuthProvider = nil } } @@ -323,6 +383,7 @@ func (k *Kafka) getLatestSchema(topic string) (*srclient.Schema, *goavro.Codec, if ok && cacheEntry.expirationTime.After(time.Now()) { return cacheEntry.schema, cacheEntry.codec, nil } + k.logger.Debugf("Cache not found or expired for subject %s. Fetching from registry...", subject) schema, errSchema := srClient.GetLatestSchema(subject) if errSchema != nil { return nil, nil, errSchema diff --git a/common/component/kafka/kafka_test.go b/common/component/kafka/kafka_test.go index 3fbe8c7a2e..cc381ed5a4 100644 --- a/common/component/kafka/kafka_test.go +++ b/common/component/kafka/kafka_test.go @@ -3,6 +3,7 @@ package kafka import ( "encoding/binary" "encoding/json" + "errors" "testing" "time" @@ -10,9 +11,12 @@ import ( gomock "github.com/golang/mock/gomock" "github.com/linkedin/goavro/v2" "github.com/riferrei/srclient" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" mock_srclient "github.com/dapr/components-contrib/common/component/kafka/mocks" + "github.com/dapr/kit/logger" ) func TestGetValueSchemaType(t *testing.T) { @@ -62,6 +66,7 @@ func TestDeserializeValue(t *testing.T) { k := Kafka{ srClient: registry, schemaCachingEnabled: true, + logger: logger.NewLogger("kafka_test"), } schemaIDBytes := make([]byte, 4) @@ -175,6 +180,7 @@ func TestSerializeValueCachingDisabled(t *testing.T) { k := Kafka{ srClient: registry, schemaCachingEnabled: false, + logger: logger.NewLogger("kafka_test"), } t.Run("valueSchemaType not set, leave value as is", func(t *testing.T) { @@ -250,6 +256,7 @@ func TestSerializeValueCachingEnabled(t *testing.T) { schemaCachingEnabled: true, latestSchemaCache: make(map[string]SchemaCacheEntry), latestSchemaCacheTTL: time.Minute * 5, + logger: logger.NewLogger("kafka_test"), } t.Run("valueSchemaType not set, leave value as is", func(t *testing.T) { @@ -280,6 +287,7 @@ func TestLatestSchemaCaching(t *testing.T) { schemaCachingEnabled: true, latestSchemaCache: make(map[string]SchemaCacheEntry), latestSchemaCacheTTL: time.Second * 10, + logger: logger.NewLogger("kafka_test"), } m.EXPECT().GetLatestSchema(gomock.Eq("my-topic-value")).Return(schema, nil).Times(1) @@ -302,6 +310,7 @@ func TestLatestSchemaCaching(t *testing.T) { schemaCachingEnabled: true, latestSchemaCache: make(map[string]SchemaCacheEntry), latestSchemaCacheTTL: time.Second * 1, + logger: logger.NewLogger("kafka_test"), } m.EXPECT().GetLatestSchema(gomock.Eq("my-topic-value")).Return(schema, nil).Times(2) @@ -326,6 +335,7 @@ func TestLatestSchemaCaching(t *testing.T) { schemaCachingEnabled: false, latestSchemaCache: make(map[string]SchemaCacheEntry), latestSchemaCacheTTL: 0, + logger: logger.NewLogger("kafka_test"), } m.EXPECT().GetLatestSchema(gomock.Eq("my-topic-value")).Return(schema, nil).Times(2) @@ -344,3 +354,81 @@ func TestLatestSchemaCaching(t *testing.T) { require.NoError(t, err) }) } + +func TestValidateAWS(t *testing.T) { + tests := []struct { + name string + metadata map[string]string + expected *awsAuth.DeprecatedKafkaIAM + err error + }{ + { + name: "Valid metadata with all fields without aws prefix", + metadata: map[string]string{ + "region": "us-east-1", + "accessKey": "testAccessKey", + "secretKey": "testSecretKey", + "assumeRoleArn": "testRoleArn", + "sessionName": "testSessionName", + "sessionToken": "testSessionToken", + }, + expected: &awsAuth.DeprecatedKafkaIAM{ + Region: "us-east-1", + AccessKey: "testAccessKey", + SecretKey: "testSecretKey", + IamRoleArn: "testRoleArn", + StsSessionName: "testSessionName", + SessionToken: "testSessionToken", + }, + err: nil, + }, + { + name: "Fallback to aws-prefixed fields with aws prefix", + metadata: map[string]string{ + "awsRegion": "us-west-2", + "awsAccessKey": "awsAccessKey", + "awsSecretKey": "awsSecretKey", + "awsIamRoleArn": "awsRoleArn", + "awsStsSessionName": "awsSessionName", + "awsSessionToken": "awsSessionToken", + }, + expected: &awsAuth.DeprecatedKafkaIAM{ + Region: "us-west-2", + AccessKey: "awsAccessKey", + SecretKey: "awsSecretKey", + IamRoleArn: "awsRoleArn", + StsSessionName: "awsSessionName", + SessionToken: "awsSessionToken", + }, + err: nil, + }, + { + name: "Missing region field", + metadata: map[string]string{ + "accessKey": "key", + "secretKey": "secret", + }, + expected: nil, + err: errors.New("metadata property AWSRegion is missing"), + }, + { + name: "Empty metadata", + metadata: map[string]string{}, + expected: nil, + err: errors.New("metadata property AWSRegion is missing"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &Kafka{} + result, err := k.ValidateAWS(tt.metadata) + if tt.err != nil { + require.EqualError(t, err, tt.err.Error()) + } else { + require.NoError(t, err) + } + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/common/component/kafka/metadata.go b/common/component/kafka/metadata.go index c4d0a6bb56..7122feb813 100644 --- a/common/component/kafka/metadata.go +++ b/common/component/kafka/metadata.go @@ -97,14 +97,7 @@ type KafkaMetadata struct { ClientConnectionTopicMetadataRefreshInterval time.Duration `mapstructure:"clientConnectionTopicMetadataRefreshInterval"` ClientConnectionKeepAliveInterval time.Duration `mapstructure:"clientConnectionKeepAliveInterval"` - // aws iam auth profile - AWSAccessKey string `mapstructure:"awsAccessKey"` - AWSSecretKey string `mapstructure:"awsSecretKey"` - AWSSessionToken string `mapstructure:"awsSessionToken"` - AWSIamRoleArn string `mapstructure:"awsIamRoleArn"` - AWSStsSessionName string `mapstructure:"awsStsSessionName"` - AWSRegion string `mapstructure:"awsRegion"` - channelBufferSize int `mapstructure:"-"` + channelBufferSize int `mapstructure:"-"` consumerFetchMin int32 `mapstructure:"-"` consumerFetchDefault int32 `mapstructure:"-"` @@ -163,6 +156,8 @@ func (k *Kafka) getKafkaMetadata(meta map[string]string) (*KafkaMetadata, error) ClientConnectionKeepAliveInterval: defaultClientConnectionKeepAliveInterval, HeartbeatInterval: 3 * time.Second, SessionTimeout: 10 * time.Second, + SchemaCachingEnabled: true, + SchemaLatestVersionCacheTTL: 5 * time.Minute, EscapeHeaders: false, } @@ -265,9 +260,6 @@ func (k *Kafka) getKafkaMetadata(meta map[string]string) (*KafkaMetadata, error) } k.logger.Debug("Configuring root certificate authentication.") case awsIAMAuthType: - if m.AWSRegion == "" { - return nil, errors.New("missing AWS region property 'awsRegion' for authType 'awsiam'") - } k.logger.Debug("Configuring AWS IAM authentication.") default: return nil, errors.New("kafka error: invalid value for 'authType' attribute") diff --git a/common/component/kafka/metadata_test.go b/common/component/kafka/metadata_test.go index e53b1721d8..acb6eb1208 100644 --- a/common/component/kafka/metadata_test.go +++ b/common/component/kafka/metadata_test.go @@ -376,20 +376,6 @@ func TestTls(t *testing.T) { }) } -func TestAwsIam(t *testing.T) { - k := getKafka() - - t.Run("missing aws region", func(t *testing.T) { - m := getBaseMetadata() - m[authType] = awsIAMAuthType - meta, err := k.getKafkaMetadata(m) - require.Error(t, err) - require.Nil(t, meta) - - require.Equal(t, "missing AWS region property 'awsRegion' for authType 'awsiam'", err.Error()) - }) -} - func TestMetadataConsumerFetchValues(t *testing.T) { k := getKafka() m := getCompleteMetadata() diff --git a/common/component/kafka/producer.go b/common/component/kafka/producer.go index 0083cb86a8..97e5a6bbed 100644 --- a/common/component/kafka/producer.go +++ b/common/component/kafka/producer.go @@ -16,6 +16,7 @@ package kafka import ( "context" "errors" + "fmt" "maps" "github.com/IBM/sarama" @@ -23,7 +24,7 @@ import ( "github.com/dapr/components-contrib/pubsub" ) -func getSyncProducer(config sarama.Config, brokers []string, maxMessageBytes int) (sarama.SyncProducer, error) { +func GetSyncProducer(config sarama.Config, brokers []string, maxMessageBytes int) (sarama.SyncProducer, error) { // Add SyncProducer specific properties to copy of base config config.Producer.RequiredAcks = sarama.WaitForAll config.Producer.Retry.Max = 5 @@ -33,7 +34,12 @@ func getSyncProducer(config sarama.Config, brokers []string, maxMessageBytes int config.Producer.MaxMessageBytes = maxMessageBytes } - producer, err := sarama.NewSyncProducer(brokers, &config) + saramaClient, err := sarama.NewClient(brokers, &config) + if err != nil { + return nil, err + } + + producer, err := sarama.NewSyncProducerFromClient(saramaClient) if err != nil { return nil, err } @@ -43,9 +49,14 @@ func getSyncProducer(config sarama.Config, brokers []string, maxMessageBytes int // Publish message to Kafka cluster. func (k *Kafka) Publish(_ context.Context, topic string, data []byte, metadata map[string]string) error { - if k.producer == nil { + clients, err := k.latestClients() + if err != nil || clients == nil { + return fmt.Errorf("failed to get latest Kafka clients: %w", err) + } + if clients.producer == nil { return errors.New("component is closed") } + // k.logger.Debugf("Publishing topic %v with data: %v", topic, string(data)) k.logger.Debugf("Publishing on topic %v", topic) @@ -73,7 +84,7 @@ func (k *Kafka) Publish(_ context.Context, topic string, data []byte, metadata m }) } - partition, offset, err := k.producer.SendMessage(msg) + partition, offset, err := clients.producer.SendMessage(msg) k.logger.Debugf("Partition: %v, offset: %v", partition, offset) @@ -85,7 +96,12 @@ func (k *Kafka) Publish(_ context.Context, topic string, data []byte, metadata m } func (k *Kafka) BulkPublish(_ context.Context, topic string, entries []pubsub.BulkMessageEntry, metadata map[string]string) (pubsub.BulkPublishResponse, error) { - if k.producer == nil { + clients, err := k.latestClients() + if err != nil || clients == nil { + err = fmt.Errorf("failed to get latest Kafka clients: %w", err) + return pubsub.NewBulkPublishResponse(entries, err), err + } + if clients.producer == nil { err := errors.New("component is closed") return pubsub.NewBulkPublishResponse(entries, err), err } @@ -134,7 +150,7 @@ func (k *Kafka) BulkPublish(_ context.Context, topic string, entries []pubsub.Bu msgs = append(msgs, msg) } - if err := k.producer.SendMessages(msgs); err != nil { + if err := clients.producer.SendMessages(msgs); err != nil { // map the returned error to different entries return k.mapKafkaProducerErrors(err, entries), err } diff --git a/common/component/kafka/producer_test.go b/common/component/kafka/producer_test.go index 3dd1b75a9e..a0769767a0 100644 --- a/common/component/kafka/producer_test.go +++ b/common/component/kafka/producer_test.go @@ -13,17 +13,15 @@ import ( ) func arrangeKafkaWithAssertions(t *testing.T, msgCheckers ...saramamocks.MessageChecker) *Kafka { - cfg := saramamocks.NewTestConfig() - mockProducer := saramamocks.NewSyncProducer(t, cfg) + mockP := saramamocks.NewSyncProducer(t, saramamocks.NewTestConfig()) for _, msgChecker := range msgCheckers { - mockProducer.ExpectSendMessageWithMessageCheckerFunctionAndSucceed(msgChecker) + mockP.ExpectSendMessageWithMessageCheckerFunctionAndSucceed(msgChecker) } return &Kafka{ - producer: mockProducer, - config: cfg, - logger: logger.NewLogger("kafka_test"), + mockProducer: mockP, + logger: logger.NewLogger("kafka_test"), } } diff --git a/common/component/kafka/subscriber.go b/common/component/kafka/subscriber.go index 95bdd5a232..4e1dc7ae8f 100644 --- a/common/component/kafka/subscriber.go +++ b/common/component/kafka/subscriber.go @@ -84,7 +84,14 @@ func (k *Kafka) reloadConsumerGroup() { func (k *Kafka) consume(ctx context.Context, topics []string, consumer *consumer) { for { - err := k.cg.Consume(ctx, topics, consumer) + clients, err := k.latestClients() + if err != nil || clients == nil { + k.logger.Errorf("failed to get latest Kafka clients: %w", err) + } + if clients.consumerGroup == nil { + k.logger.Errorf("component is closed") + } + err = clients.consumerGroup.Consume(ctx, topics, consumer) if errors.Is(err, context.Canceled) { return } diff --git a/common/component/kafka/subscriber_test.go b/common/component/kafka/subscriber_test.go index 57b87cf4f2..dbfc696341 100644 --- a/common/component/kafka/subscriber_test.go +++ b/common/component/kafka/subscriber_test.go @@ -41,11 +41,11 @@ func Test_reloadConsumerGroup(t *testing.T) { }) k := &Kafka{ - logger: logger.NewLogger("test"), - cg: cg, - subscribeTopics: nil, - closeCh: make(chan struct{}), - consumerCancel: cancel, + logger: logger.NewLogger("test"), + mockConsumerGroup: cg, + subscribeTopics: nil, + closeCh: make(chan struct{}), + consumerCancel: cancel, } k.reloadConsumerGroup() @@ -64,11 +64,11 @@ func Test_reloadConsumerGroup(t *testing.T) { return nil }) k := &Kafka{ - logger: logger.NewLogger("test"), - cg: cg, - consumerCancel: cancel, - closeCh: make(chan struct{}), - subscribeTopics: TopicHandlerConfig{"foo": SubscriptionHandlerConfig{}}, + logger: logger.NewLogger("test"), + mockConsumerGroup: cg, + consumerCancel: cancel, + closeCh: make(chan struct{}), + subscribeTopics: TopicHandlerConfig{"foo": SubscriptionHandlerConfig{}}, } k.closed.Store(true) @@ -89,11 +89,11 @@ func Test_reloadConsumerGroup(t *testing.T) { return nil }) k := &Kafka{ - logger: logger.NewLogger("test"), - cg: cg, - consumerCancel: nil, - closeCh: make(chan struct{}), - subscribeTopics: TopicHandlerConfig{"foo": SubscriptionHandlerConfig{}}, + logger: logger.NewLogger("test"), + mockConsumerGroup: cg, + consumerCancel: nil, + closeCh: make(chan struct{}), + subscribeTopics: TopicHandlerConfig{"foo": SubscriptionHandlerConfig{}}, } k.reloadConsumerGroup() @@ -114,7 +114,7 @@ func Test_reloadConsumerGroup(t *testing.T) { }) k := &Kafka{ logger: logger.NewLogger("test"), - cg: cg, + mockConsumerGroup: cg, consumerCancel: nil, closeCh: make(chan struct{}), subscribeTopics: TopicHandlerConfig{"foo": SubscriptionHandlerConfig{}}, @@ -146,7 +146,7 @@ func Test_reloadConsumerGroup(t *testing.T) { }) k := &Kafka{ logger: logger.NewLogger("test"), - cg: cg, + mockConsumerGroup: cg, consumerCancel: nil, closeCh: make(chan struct{}), subscribeTopics: map[string]SubscriptionHandlerConfig{"foo": {}}, @@ -174,7 +174,7 @@ func Test_reloadConsumerGroup(t *testing.T) { }) k := &Kafka{ logger: logger.NewLogger("test"), - cg: cg, + mockConsumerGroup: cg, consumerCancel: nil, closeCh: make(chan struct{}), subscribeTopics: map[string]SubscriptionHandlerConfig{"foo": {}}, @@ -210,7 +210,7 @@ func Test_reloadConsumerGroup(t *testing.T) { }) k := &Kafka{ logger: logger.NewLogger("test"), - cg: cg, + mockConsumerGroup: cg, consumerCancel: nil, closeCh: make(chan struct{}), subscribeTopics: map[string]SubscriptionHandlerConfig{"foo": {}}, @@ -248,7 +248,7 @@ func Test_Subscribe(t *testing.T) { }) k := &Kafka{ logger: logger.NewLogger("test"), - cg: cg, + mockConsumerGroup: cg, consumerCancel: nil, closeCh: make(chan struct{}), consumeRetryInterval: time.Millisecond, @@ -273,7 +273,7 @@ func Test_Subscribe(t *testing.T) { }) k := &Kafka{ logger: logger.NewLogger("test"), - cg: cg, + mockConsumerGroup: cg, consumerCancel: nil, closeCh: make(chan struct{}), consumeRetryInterval: time.Millisecond, @@ -302,7 +302,7 @@ func Test_Subscribe(t *testing.T) { }) k := &Kafka{ logger: logger.NewLogger("test"), - cg: cg, + mockConsumerGroup: cg, consumerCancel: nil, closeCh: make(chan struct{}), consumeRetryInterval: time.Millisecond, @@ -340,7 +340,7 @@ func Test_Subscribe(t *testing.T) { }) k := &Kafka{ logger: logger.NewLogger("test"), - cg: cg, + mockConsumerGroup: cg, consumerCancel: nil, closeCh: make(chan struct{}), consumeRetryInterval: time.Millisecond, @@ -391,7 +391,7 @@ func Test_Subscribe(t *testing.T) { }) k := &Kafka{ logger: logger.NewLogger("test"), - cg: cg, + mockConsumerGroup: cg, consumerCancel: nil, closeCh: make(chan struct{}), subscribeTopics: make(TopicHandlerConfig), @@ -421,7 +421,7 @@ func Test_Subscribe(t *testing.T) { }) k := &Kafka{ logger: logger.NewLogger("test"), - cg: cg, + mockConsumerGroup: cg, consumerCancel: nil, closeCh: make(chan struct{}), subscribeTopics: make(TopicHandlerConfig), @@ -495,7 +495,7 @@ func Test_Subscribe(t *testing.T) { }) k := &Kafka{ logger: logger.NewLogger("test"), - cg: cg, + mockConsumerGroup: cg, consumerCancel: nil, closeCh: make(chan struct{}), subscribeTopics: make(TopicHandlerConfig), diff --git a/common/component/postgresql/v1/metadata.go b/common/component/postgresql/v1/metadata.go index 72a26a226b..7bef937115 100644 --- a/common/component/postgresql/v1/metadata.go +++ b/common/component/postgresql/v1/metadata.go @@ -39,7 +39,7 @@ type pgMetadata struct { Timeout time.Duration `mapstructure:"timeout" mapstructurealiases:"timeoutInSeconds"` CleanupInterval *time.Duration `mapstructure:"cleanupInterval" mapstructurealiases:"cleanupIntervalInSeconds"` - aws.AWSIAM `mapstructure:",squash"` + aws.DeprecatedPostgresIAM `mapstructure:",squash"` } func (m *pgMetadata) InitWithMetadata(meta state.Metadata, opts pgauth.InitWithMetadataOpts) error { diff --git a/common/component/postgresql/v1/postgresql.go b/common/component/postgresql/v1/postgresql.go index 3f99559870..636c19a493 100644 --- a/common/component/postgresql/v1/postgresql.go +++ b/common/component/postgresql/v1/postgresql.go @@ -28,6 +28,7 @@ import ( "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" pgauth "github.com/dapr/components-contrib/common/authentication/postgresql" pginterfaces "github.com/dapr/components-contrib/common/component/postgresql/interfaces" pgtransactions "github.com/dapr/components-contrib/common/component/postgresql/transactions" @@ -54,6 +55,8 @@ type PostgreSQL struct { etagColumn string enableAzureAD bool enableAWSIAM bool + + awsAuthProvider awsAuth.Provider } type Options struct { @@ -96,16 +99,31 @@ func (p *PostgreSQL) Init(ctx context.Context, meta state.Metadata) error { AWSIAMEnabled: p.enableAWSIAM, } - err := p.metadata.InitWithMetadata(meta, opts) - if err != nil { + if err := p.metadata.InitWithMetadata(meta, opts); err != nil { return fmt.Errorf("failed to parse metadata: %w", err) } + var err error config, err := p.metadata.GetPgxPoolConfig() if err != nil { return err } + if opts.AWSIAMEnabled && p.metadata.UseAWSIAM { + opts, validateErr := p.metadata.BuildAwsIamOptions(p.logger, meta.Properties) + if validateErr != nil { + return fmt.Errorf("failed to validate AWS IAM authentication fields: %w", validateErr) + } + + var provider awsAuth.Provider + provider, err = awsAuth.NewProvider(ctx, *opts, awsAuth.GetConfig(*opts)) + if err != nil { + return err + } + p.awsAuthProvider = provider + p.awsAuthProvider.UpdatePostgres(ctx, config) + } + connCtx, connCancel := context.WithTimeout(ctx, p.metadata.Timeout) p.db, err = pgxpool.NewWithConfig(connCtx, config) connCancel() @@ -491,11 +509,15 @@ func (p *PostgreSQL) Close() error { p.db = nil } + errs := make([]error, 2) if p.gc != nil { - return p.gc.Close() + errs[0] = p.gc.Close() } - return nil + if p.awsAuthProvider != nil { + errs[1] = p.awsAuthProvider.Close() + } + return errors.Join(errs...) } // GetCleanupInterval returns the cleanupInterval property. diff --git a/common/component/redis/redis.go b/common/component/redis/redis.go index 818a4e6a1a..f08d3c7c1e 100644 --- a/common/component/redis/redis.go +++ b/common/component/redis/redis.go @@ -214,21 +214,21 @@ func ParseClientFromProperties(properties map[string]string, componentType metad // start the token refresh goroutine if settings.UseEntraID { - StartEntraIDTokenRefreshBackgroundRoutine(c, settings.Username, *tokenExpires, tokenCredential, ctx, logger) + StartEntraIDTokenRefreshBackgroundRoutine(c, settings.Username, *tokenExpires, tokenCredential, logger) } return c, &settings, nil } -func StartEntraIDTokenRefreshBackgroundRoutine(client RedisClient, username string, nextExpiration time.Time, cred *azcore.TokenCredential, parentCtx context.Context, logger *kitlogger.Logger) { +func StartEntraIDTokenRefreshBackgroundRoutine(client RedisClient, username string, nextExpiration time.Time, cred *azcore.TokenCredential, logger *kitlogger.Logger) { go func(cred *azcore.TokenCredential, username string, logger *kitlogger.Logger) { - ctx, cancel := context.WithCancel(parentCtx) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() backoffConfig := kitretry.DefaultConfig() backoffConfig.MaxRetries = 3 backoffConfig.Policy = kitretry.PolicyExponential var backoffManager backoff.BackOff - const refreshGracePeriod = 2 * time.Minute + const refreshGracePeriod = 5 * time.Minute tokenRefreshDuration := time.Until(nextExpiration.Add(-refreshGracePeriod)) (*logger).Debugf("redis client: starting entraID token refresh loop") diff --git a/configuration/postgres/metadata.go b/configuration/postgres/metadata.go index 3ca391d3a8..9ec52a0908 100644 --- a/configuration/postgres/metadata.go +++ b/configuration/postgres/metadata.go @@ -32,7 +32,7 @@ type metadata struct { Timeout time.Duration `mapstructure:"timeout" mapstructurealiases:"timeoutInSeconds"` ConfigTable string `mapstructure:"table"` MaxIdleTimeoutOld time.Duration `mapstructure:"connMaxIdleTime"` // Deprecated alias for "connectionMaxIdleTime" - aws.AWSIAM `mapstructure:",squash"` + aws.DeprecatedPostgresIAM `mapstructure:",squash"` } func (m *metadata) InitWithMetadata(meta map[string]string) error { diff --git a/configuration/postgres/metadata.yaml b/configuration/postgres/metadata.yaml index 401d704e30..16db28f2ee 100644 --- a/configuration/postgres/metadata.yaml +++ b/configuration/postgres/metadata.yaml @@ -46,23 +46,21 @@ builtinAuthenticationProfiles: example: | "host=mydb.postgres.database.aws.com user=myapplication port=5432 dbname=dapr_test sslmode=require" type: string - - name: awsRegion - type: string - required: true - description: | - The AWS Region where the AWS Relational Database Service is deployed to. - example: '"us-east-1"' - name: awsAccessKey type: string - required: true + required: false description: | + Deprecated as of Dapr 1.17. Use 'accessKey' instead if using AWS IAM. + If both fields are set, then 'accessKey' value will be used. AWS access key associated with an IAM account. example: '"AKIAIOSFODNN7EXAMPLE"' - name: awsSecretKey type: string - required: true + required: false sensitive: true description: | + Deprecated as of Dapr 1.17. Use 'secretKey' instead if using AWS IAM. + If both fields are set, then 'secretKey' value will be used. The secret key associated with the access key. example: '"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"' authenticationProfiles: diff --git a/configuration/postgres/postgres.go b/configuration/postgres/postgres.go index 63119ab7d2..734a199f0d 100644 --- a/configuration/postgres/postgres.go +++ b/configuration/postgres/postgres.go @@ -31,6 +31,8 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" + pgauth "github.com/dapr/components-contrib/common/authentication/postgresql" "github.com/dapr/components-contrib/configuration" contribMetadata "github.com/dapr/components-contrib/metadata" "github.com/dapr/kit/logger" @@ -47,6 +49,10 @@ type ConfigurationStore struct { wg sync.WaitGroup closed atomic.Bool lock sync.RWMutex + + enableAzureAD bool + enableAWSIAM bool + awsAuthProvider awsAuth.Provider } type subscription struct { @@ -77,6 +83,10 @@ func NewPostgresConfigurationStore(logger logger.Logger) configuration.Store { } func (p *ConfigurationStore) Init(ctx context.Context, metadata configuration.Metadata) error { + opts := pgauth.InitWithMetadataOpts{ + AzureADEnabled: p.enableAzureAD, + AWSIAMEnabled: p.enableAWSIAM, + } err := p.metadata.InitWithMetadata(metadata.Properties) if err != nil { p.logger.Error(err) @@ -84,10 +94,36 @@ func (p *ConfigurationStore) Init(ctx context.Context, metadata configuration.Me } p.ActiveSubscriptions = make(map[string]*subscription) - p.client, err = p.connectDB(ctx) + config, err := p.metadata.GetPgxPoolConfig() + if err != nil { + return fmt.Errorf("PostgreSQL configuration store connection error: %s", err) + } + + if opts.AWSIAMEnabled && p.metadata.UseAWSIAM { + opts, validateErr := p.metadata.BuildAwsIamOptions(p.logger, metadata.Properties) + if validateErr != nil { + return fmt.Errorf("failed to validate AWS IAM authentication fields: %w", validateErr) + } + + var provider awsAuth.Provider + provider, err = awsAuth.NewProvider(ctx, *opts, awsAuth.GetConfig(*opts)) + if err != nil { + return err + } + p.awsAuthProvider = provider + p.awsAuthProvider.UpdatePostgres(ctx, config) + } + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return fmt.Errorf("PostgreSQL configuration store connection error: %w", err) + } + + err = pool.Ping(ctx) if err != nil { - return fmt.Errorf("error connecting to configuration store: '%w'", err) + return fmt.Errorf("PostgreSQL configuration store ping error: %w", err) } + p.client = pool err = p.client.Ping(ctx) if err != nil { @@ -304,25 +340,6 @@ func (p *ConfigurationStore) handleSubscribedChange(ctx context.Context, handler } } -func (p *ConfigurationStore) connectDB(ctx context.Context) (*pgxpool.Pool, error) { - config, err := p.metadata.GetPgxPoolConfig() - if err != nil { - return nil, fmt.Errorf("PostgreSQL configuration store connection error: %s", err) - } - - pool, err := pgxpool.NewWithConfig(ctx, config) - if err != nil { - return nil, fmt.Errorf("PostgreSQL configuration store connection error: %w", err) - } - - err = pool.Ping(ctx) - if err != nil { - return nil, fmt.Errorf("PostgreSQL configuration store ping error: %w", err) - } - - return pool, nil -} - func buildQuery(req *configuration.GetRequest, configTable string) (string, []interface{}, error) { var query string var params []interface{} @@ -436,5 +453,9 @@ func (p *ConfigurationStore) Close() error { p.client.Close() } - return nil + errs := make([]error, 1) + if p.awsAuthProvider != nil { + errs[0] = p.awsAuthProvider.Close() + } + return errors.Join(errs...) } diff --git a/go.mod b/go.mod index ea0cb75bd0..834e4bfc31 100644 --- a/go.mod +++ b/go.mod @@ -39,13 +39,15 @@ require ( github.com/apache/pulsar-client-go v0.11.0 github.com/apache/rocketmq-client-go/v2 v2.1.2-0.20230412142645-25003f6f083d github.com/apache/thrift v0.13.0 - github.com/aws/aws-msk-iam-sasl-signer-go v1.0.0 + github.com/aws/aws-msk-iam-sasl-signer-go v1.0.1-0.20241125194140-078c08b8574a github.com/aws/aws-sdk-go v1.50.19 - github.com/aws/aws-sdk-go-v2 v1.31.0 - github.com/aws/aws-sdk-go-v2/config v1.27.39 - github.com/aws/aws-sdk-go-v2/credentials v1.17.37 + github.com/aws/aws-sdk-go-v2 v1.32.4 + github.com/aws/aws-sdk-go-v2/config v1.28.2 + github.com/aws/aws-sdk-go-v2/credentials v1.17.43 github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.3.10 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.17.3 + github.com/aws/aws-sdk-go-v2/service/sts v1.32.4 + github.com/aws/rolesanywhere-credential-helper v1.0.4 github.com/bradfitz/gomemcache v0.0.0-20230905024940-24af94b03874 github.com/camunda/zeebe/clients/go/v8 v8.2.12 github.com/cenkalti/backoff/v4 v4.2.1 @@ -107,6 +109,7 @@ require ( github.com/sendgrid/sendgrid-go v3.13.0+incompatible github.com/sijms/go-ora/v2 v2.7.18 github.com/spf13/cast v1.5.1 + github.com/spiffe/go-spiffe/v2 v2.1.7 github.com/stealthrocket/wasi-go v0.8.1-0.20230912180546-8efbab50fb58 github.com/stretchr/testify v1.9.0 github.com/supplyon/gremcos v0.1.40 @@ -179,16 +182,15 @@ require ( github.com/armon/go-metrics v0.4.1 // indirect github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.5 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.14 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.18 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.18 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.19 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.23 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.23 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.5 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.20 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.23.3 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.27.3 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.31.3 // indirect - github.com/aws/smithy-go v1.21.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.4 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.24.4 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.4 // indirect + github.com/aws/smithy-go v1.22.0 // indirect github.com/awslabs/kinesis-aggregation/go v0.0.0-20210630091500-54e17340d32f // indirect github.com/benbjohnson/clock v1.3.5 // indirect github.com/beorn7/perks v1.0.1 // indirect @@ -380,6 +382,7 @@ require ( github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 // indirect github.com/yuin/gopher-lua v1.1.0 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect + github.com/zeebo/errs v1.3.0 // indirect go.etcd.io/etcd/api/v3 v3.5.9 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.9 // indirect go.opencensus.io v0.24.0 // indirect @@ -403,6 +406,7 @@ require ( google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240513163218-0867130af1f8 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240509183442-62759503f434 // indirect + google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect gopkg.in/couchbase/gocbcore.v7 v7.1.18 // indirect gopkg.in/couchbaselabs/gocbconnstr.v1 v1.0.4 // indirect diff --git a/go.sum b/go.sum index a9538f188f..8b0dccd210 100644 --- a/go.sum +++ b/go.sum @@ -124,6 +124,8 @@ github.com/HdrHistogram/hdrhistogram-go v1.1.2/go.mod h1:yDgFjdqOqDEKOvasDdhWNXY github.com/IBM/sarama v1.43.3 h1:Yj6L2IaNvb2mRBop39N7mmJAHBVY3dTPncr3qGVkxPA= github.com/IBM/sarama v1.43.3/go.mod h1:FVIRaLrhK3Cla/9FfRF5X9Zua2KpS3SYIXxhac1H+FQ= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= +github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= +github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= github.com/Netflix/go-env v0.0.0-20220526054621-78278af1949d h1:wvStE9wLpws31NiWUx+38wny1msZ/tm+eL5xmm4Y7So= github.com/Netflix/go-env v0.0.0-20220526054621-78278af1949d/go.mod h1:9XMFaCeRyW7fC9XJOWQ+NdAv8VLG7ys7l3x4ozEGLUQ= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= @@ -242,8 +244,8 @@ github.com/asaskevich/govalidator v0.0.0-20200108200545-475eaeb16496/go.mod h1:o github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/aws/aws-lambda-go v1.13.3/go.mod h1:4UKl9IzQMoD+QF79YdCuzCwp8VbmG4VAQwij/eHl5CU= -github.com/aws/aws-msk-iam-sasl-signer-go v1.0.0 h1:UyjtGmO0Uwl/K+zpzPwLoXzMhcN9xmnR2nrqJoBrg3c= -github.com/aws/aws-msk-iam-sasl-signer-go v1.0.0/go.mod h1:TJAXuFs2HcMib3sN5L0gUC+Q01Qvy3DemvA55WuC+iA= +github.com/aws/aws-msk-iam-sasl-signer-go v1.0.1-0.20241125194140-078c08b8574a h1:QFemvMGPnajaeRBkFc1HoEA7qzVjUv+rkYb1/ps1/UE= +github.com/aws/aws-msk-iam-sasl-signer-go v1.0.1-0.20241125194140-078c08b8574a/go.mod h1:MVYeeOhILFFemC/XlYTClvBjYZrg/EPd3ts885KrNTI= github.com/aws/aws-sdk-go v1.19.48/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/aws/aws-sdk-go v1.32.6/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= @@ -251,47 +253,49 @@ github.com/aws/aws-sdk-go v1.50.19 h1:YSIDKRSkh/TW0RPWoocdLqtC/T5W6IGBVhFs6P7Qca github.com/aws/aws-sdk-go v1.50.19/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= github.com/aws/aws-sdk-go-v2 v1.9.2/go.mod h1:cK/D0BBs0b/oWPIcX/Z/obahJK1TT7IPVjy53i/mX/4= -github.com/aws/aws-sdk-go-v2 v1.31.0 h1:3V05LbxTSItI5kUqNwhJrrrY1BAXxXt0sN0l72QmG5U= -github.com/aws/aws-sdk-go-v2 v1.31.0/go.mod h1:ztolYtaEUtdpf9Wftr31CJfLVjOnD/CVRkKOOYgF8hA= +github.com/aws/aws-sdk-go-v2 v1.32.4 h1:S13INUiTxgrPueTmrm5DZ+MiAo99zYzHEFh1UNkOxNE= +github.com/aws/aws-sdk-go-v2 v1.32.4/go.mod h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v/HpXKM1ZUo= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.5 h1:xDAuZTn4IMm8o1LnBZvmrL8JA1io4o3YWNXgohbf20g= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.5/go.mod h1:wYSv6iDS621sEFLfKvpPE2ugjTuGlAG7iROg0hLOkfc= github.com/aws/aws-sdk-go-v2/config v1.8.3/go.mod h1:4AEiLtAb8kLs7vgw2ZV3p2VZ1+hBavOc84hqxVNpCyw= -github.com/aws/aws-sdk-go-v2/config v1.27.39 h1:FCylu78eTGzW1ynHcongXK9YHtoXD5AiiUqq3YfJYjU= -github.com/aws/aws-sdk-go-v2/config v1.27.39/go.mod h1:wczj2hbyskP4LjMKBEZwPRO1shXY+GsQleab+ZXT2ik= +github.com/aws/aws-sdk-go-v2/config v1.28.2 h1:FLvWA97elBiSPdIol4CXfIAY1wlq3KzoSgkMuZSuSe8= +github.com/aws/aws-sdk-go-v2/config v1.28.2/go.mod h1:hNmQsKfUqpKz2yfnZUB60GCemPmeqAalVTui0gOxjAE= github.com/aws/aws-sdk-go-v2/credentials v1.4.3/go.mod h1:FNNC6nQZQUuyhq5aE5c7ata8o9e4ECGmS4lAXC7o1mQ= -github.com/aws/aws-sdk-go-v2/credentials v1.17.37 h1:G2aOH01yW8X373JK419THj5QVqu9vKEwxSEsGxihoW0= -github.com/aws/aws-sdk-go-v2/credentials v1.17.37/go.mod h1:0ecCjlb7htYCptRD45lXJ6aJDQac6D2NlKGpZqyTG6A= +github.com/aws/aws-sdk-go-v2/credentials v1.17.43 h1:SEGdVOOE1Wyr2XFKQopQ5GYjym3nYHcphesdt78rNkY= +github.com/aws/aws-sdk-go-v2/credentials v1.17.43/go.mod h1:3aiza5kSyAE4eujSanOkSkAmX/RnVqslM+GRQ/Xvv4c= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.6.0/go.mod h1:gqlclDEZp4aqJOancXK6TN24aKhT0W0Ae9MHk3wzTMM= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.14 h1:C/d03NAmh8C4BZXhuRNboF/DqhBkBCeDiJDcaqIT5pA= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.14/go.mod h1:7I0Ju7p9mCIdlrfS+JCgqcYD0VXz/N4yozsox+0o078= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.19 h1:woXadbf0c7enQ2UGCi8gW/WuKmE0xIzxBF/eD94jMKQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.19/go.mod h1:zminj5ucw7w0r65bP6nhyOd3xL6veAUMc3ElGMoLVb4= github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.3.10 h1:z6fAXB4HSuYjrE/P8RU3NdCaN+EPaeq/+80aisCjuF8= github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.3.10/go.mod h1:PoPjOi7j+/DtKIGC58HRfcdWKBPYYXwdKnRG+po+hzo= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.18 h1:kYQ3H1u0ANr9KEKlGs/jTLrBFPo8P8NaH/w7A01NeeM= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.18/go.mod h1:r506HmK5JDUh9+Mw4CfGJGSSoqIiLCndAuqXuhbv67Y= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.18 h1:Z7IdFUONvTcvS7YuhtVxN99v2cCoHRXOS4mTr0B/pUc= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.18/go.mod h1:DkKMmksZVVyat+Y+r1dEOgJEfUeA7UngIHWeKsi0yNc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.23 h1:A2w6m6Tmr+BNXjDsr7M90zkWjsu4JXHwrzPg235STs4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.23/go.mod h1:35EVp9wyeANdujZruvHiQUAo9E3vbhnIO1mTCAxMlY0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.23 h1:pgYW9FCabt2M25MoHYCfMrVY2ghiiBKYWUVXfwZs+sU= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.23/go.mod h1:c48kLgzO19wAu3CPkDWC28JbaJ+hfQlsdl7I2+oqIbk= github.com/aws/aws-sdk-go-v2/internal/ini v1.2.4/go.mod h1:ZcBrrI3zBKlhGFNYWvju0I3TR93I7YIgAfy82Fh4lcQ= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= github.com/aws/aws-sdk-go-v2/service/appconfig v1.4.2/go.mod h1:FZ3HkCe+b10uFZZkFdvf98LHW21k49W8o8J366lqVKY= github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.17.3 h1:PtP2Zzf3uy94EsVOW+tB7gNt63fFZEHuS9IRWg5q250= github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.17.3/go.mod h1:4zuvYEUJm0Vq8tb3gcb2sl04A9I1AA5DKAefbYPA4VM= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.5 h1:QFASJGfT8wMXtuP3D5CRmMjARHv9ZmzFUMJznHDOY3w= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.5/go.mod h1:QdZ3OmoIjSX+8D1OPAzPxDfjXASbBMDsz9qvtyIhtik= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0 h1:TToQNkvGguu209puTojY/ozlqy2d/SFNcoLIqTFi42g= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0/go.mod h1:0jp+ltwkf+SwG2fm/PKo8t4y8pJSgOCO4D8Lz3k0aHQ= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.3.2/go.mod h1:72HRZDLMtmVQiLG2tLfQcaWLCssELvGl+Zf2WVxMmR8= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.20 h1:Xbwbmk44URTiHNx6PNo0ujDE6ERlsCKJD3u1zfnzAPg= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.20/go.mod h1:oAfOFzUB14ltPZj1rWwRc3d/6OgD76R8KlvU3EqM9Fg= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.4 h1:tHxQi/XHPK0ctd/wdOw0t7Xrc2OxcRCnVzv8lwWPu0c= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.4/go.mod h1:4GQbF1vJzG60poZqWatZlhP31y8PGCCVTvIGPdaaYJ0= github.com/aws/aws-sdk-go-v2/service/sso v1.4.2/go.mod h1:NBvT9R1MEF+Ud6ApJKM0G+IkPchKS7p7c2YPKwHmBOk= -github.com/aws/aws-sdk-go-v2/service/sso v1.23.3 h1:rs4JCczF805+FDv2tRhZ1NU0RB2H6ryAvsWPanAr72Y= -github.com/aws/aws-sdk-go-v2/service/sso v1.23.3/go.mod h1:XRlMvmad0ZNL+75C5FYdMvbbLkd6qiqz6foR1nA1PXY= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.27.3 h1:S7EPdMVZod8BGKQQPTBK+FcX9g7bKR7c4+HxWqHP7Vg= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.27.3/go.mod h1:FnvDM4sfa+isJ3kDXIzAB9GAwVSzFzSy97uZ3IsHo4E= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.4 h1:BqE3NRG6bsODh++VMKMsDmFuJTHrdD4rJZqHjDeF6XI= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.4/go.mod h1:wrMCEwjFPms+V86TCQQeOxQF/If4vT44FGIOFiMC2ck= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.4 h1:zcx9LiGWZ6i6pjdcoE9oXAB6mUdeyC36Ia/QEiIvYdg= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.4/go.mod h1:Tp/ly1cTjRLGBBmNccFumbZ8oqpZlpdhFf80SrRh4is= github.com/aws/aws-sdk-go-v2/service/sts v1.7.2/go.mod h1:8EzeIqfWt2wWT4rJVu3f21TfrhJ8AEMzVybRNSb/b4g= -github.com/aws/aws-sdk-go-v2/service/sts v1.31.3 h1:VzudTFrDCIDakXtemR7l6Qzt2+JYsVqo2MxBPt5k8T8= -github.com/aws/aws-sdk-go-v2/service/sts v1.31.3/go.mod h1:yMWe0F+XG0DkRZK5ODZhG7BEFYhLXi2dqGsv6tX0cgI= +github.com/aws/aws-sdk-go-v2/service/sts v1.32.4 h1:yDxvkz3/uOKfxnv8YhzOi9m+2OGIxF+on3KOISbK5IU= +github.com/aws/aws-sdk-go-v2/service/sts v1.32.4/go.mod h1:9XEUty5v5UAsMiFOBJrNibZgwCeOma73jgGwwhgffa8= +github.com/aws/rolesanywhere-credential-helper v1.0.4 h1:kHIVVdyQQiFZoKBP+zywBdFilGCS8It+UvW5LolKbW8= +github.com/aws/rolesanywhere-credential-helper v1.0.4/go.mod h1:QVGNxlDlYhjR0/ZUee7uGl0hNChWidNpe2+GD87Buqk= github.com/aws/smithy-go v1.8.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E= -github.com/aws/smithy-go v1.21.0 h1:H7L8dtDRk0P1Qm6y0ji7MCYMQObJ5R9CRpyPhRUkLYA= -github.com/aws/smithy-go v1.21.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/aws/smithy-go v1.22.0 h1:uunKnWlcoL3zO7q+gG2Pk53joueEOsnNB28QdMsmiMM= +github.com/aws/smithy-go v1.22.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/awslabs/kinesis-aggregation/go v0.0.0-20210630091500-54e17340d32f h1:Pf0BjJDga7C98f0vhw+Ip5EaiE07S3lTKpIYPNS0nMo= github.com/awslabs/kinesis-aggregation/go v0.0.0-20210630091500-54e17340d32f/go.mod h1:SghidfnxvX7ribW6nHI7T+IBbc9puZ9kk5Tx/88h8P4= github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= @@ -605,6 +609,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= +github.com/go-jose/go-jose/v3 v3.0.1 h1:pWmKFVtt+Jl0vBZTIpz/eAKwsm6LkIxDVVbFHKkchhA= +github.com/go-jose/go-jose/v3 v3.0.1/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.10.0 h1:dXFJfIHVvUcpSgDOV+Ne6t7jXri8Tfv2uOLHUZ2XNuo= @@ -1525,6 +1531,8 @@ github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5q github.com/spf13/viper v1.7.1/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg= github.com/spf13/viper v1.15.0 h1:js3yy885G8xwJa6iOISGFwd+qlUo5AvyXb7CiihdtiU= github.com/spf13/viper v1.15.0/go.mod h1:fFcTBJxvhhzSJiZy8n+PeW6t8l+KeT/uTARa0jHOQLA= +github.com/spiffe/go-spiffe/v2 v2.1.7 h1:VUkM1yIyg/x8X7u1uXqSRVRCdMdfRIEdFBzpqoeASGk= +github.com/spiffe/go-spiffe/v2 v2.1.7/go.mod h1:QJDGdhXllxjxvd5B+2XnhhXB/+rC8gr+lNrtOryiWeE= github.com/stealthrocket/wasi-go v0.8.1-0.20230912180546-8efbab50fb58 h1:mTC4gyv3lcJ1XpzZMAckqkvWUqeT5Bva4RAT1IoHAAA= github.com/stealthrocket/wasi-go v0.8.1-0.20230912180546-8efbab50fb58/go.mod h1:ZAYCOqLJkc9P6fcq14TV4cf+gJ2fHthp9kCGxBViagE= github.com/stealthrocket/wazergo v0.19.1 h1:BPrITETPgSFwiytwmToO0MbUC/+RGC39JScz1JmmG6c= @@ -1660,6 +1668,8 @@ github.com/yuin/gopher-lua v1.1.0/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7 github.com/yusufpapurcu/wmi v1.2.2/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw= github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= +github.com/zeebo/errs v1.3.0 h1:hmiaKqgYZzcVgRL1Vkc1Mn2914BbzB0IBxs+ebeutGs= +github.com/zeebo/errs v1.3.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4= github.com/zouyx/agollo/v3 v3.4.5 h1:7YCxzY9ZYaH9TuVUBvmI6Tk0mwMggikah+cfbYogcHQ= github.com/zouyx/agollo/v3 v3.4.5/go.mod h1:LJr3kDmm23QSW+F1Ol4TMHDa7HvJvscMdVxJ2IpUTVc= go.einride.tech/aip v0.66.0 h1:XfV+NQX6L7EOYK11yoHHFtndeaWh3KbD9/cN/6iWEt8= @@ -2324,6 +2334,8 @@ google.golang.org/grpc v1.47.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACu google.golang.org/grpc v1.48.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk= google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY= google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg= +google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20 h1:MLBCGN1O7GzIx+cBiwfYPwtmZ41U3Mn/cotLJciaArI= +google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20/go.mod h1:Nr5H8+MlGWr5+xX/STzdoEqJrO+YteqFbMyCsrb6mH0= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/pubsub/aws/snssqs/metadata.go b/pubsub/aws/snssqs/metadata.go index db45fb8d84..f79ed207af 100644 --- a/pubsub/aws/snssqs/metadata.go +++ b/pubsub/aws/snssqs/metadata.go @@ -22,6 +22,7 @@ type snsSqsMetadata struct { // aws endpoint for the component to use. Endpoint string `mapstructure:"endpoint"` // aws region in which SNS/SQS should create resources. + // TODO: rm the alias on region in Dapr 1.17. Region string `json:"region" mapstructure:"region" mapstructurealiases:"awsRegion" mdignore:"true"` // aws partition in which SNS/SQS should create resources. internalPartition string `mapstructure:"-"` @@ -57,6 +58,8 @@ type snsSqsMetadata struct { AccountID string `mapstructure:"accountID"` // processing concurrency mode ConcurrencyMode pubsub.ConcurrencyMode `mapstructure:"concurrencyMode"` + // limits the number of concurrent goroutines + ConcurrencyLimit int `mapstructure:"concurrencyLimit"` } func maskLeft(s string) string { @@ -67,7 +70,7 @@ func maskLeft(s string) string { return string(rs) } -func (s *snsSqs) getSnsSqsMetatdata(meta pubsub.Metadata) (*snsSqsMetadata, error) { +func (s *snsSqs) getSnsSqsMetadata(meta pubsub.Metadata) (*snsSqsMetadata, error) { md := &snsSqsMetadata{ AssetsManagementTimeoutSeconds: assetsManagementDefaultTimeoutSeconds, MessageVisibilityTimeout: 10, @@ -130,6 +133,10 @@ func (s *snsSqs) getSnsSqsMetatdata(meta pubsub.Metadata) (*snsSqsMetadata, erro return nil, err } + if md.ConcurrencyLimit < 0 { + return nil, errors.New("concurrencyLimit must be greater than or equal to 0") + } + s.logger.Debug(md.hideDebugPrintedCredentials()) return md, nil diff --git a/pubsub/aws/snssqs/metadata.yaml b/pubsub/aws/snssqs/metadata.yaml index 641ab4ea07..e121eaf005 100644 --- a/pubsub/aws/snssqs/metadata.yaml +++ b/pubsub/aws/snssqs/metadata.yaml @@ -128,6 +128,15 @@ metadata: default: '"parallel"' example: '"single", "parallel"' type: string + - name: concurrencyLimit + required: false + description: | + Defines the maximum number of concurrent workers handling messages. + This value is ignored when "concurrencyMode" is set to “single“. + To avoid limiting the number of concurrent workers set this to “0“. + type: number + default: '0' + example: '100' - name: accountId required: false description: | diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index 357cfcabb9..e84dbeb80a 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -49,9 +49,7 @@ type snsSqs struct { queues map[string]*sqsQueueInfo // key is a composite key of queue ARN and topic ARN mapping to subscription ARN. subscriptions map[string]string - snsClient *sns.SNS - sqsClient *sqs.SQS - stsClient *sts.STS + authProvider awsAuth.Provider metadata *snsSqsMetadata logger logger.Logger id string @@ -138,23 +136,33 @@ func nameToAWSSanitizedName(name string, isFifo bool) string { } func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error { - md, err := s.getSnsSqsMetatdata(metadata) + m, err := s.getSnsSqsMetadata(metadata) if err != nil { return err } - s.metadata = md + s.metadata = m - sess, err := awsAuth.GetClient(md.AccessKey, md.SecretKey, md.SessionToken, md.Region, md.Endpoint) - if err != nil { - return fmt.Errorf("error creating an AWS client: %w", err) + if s.authProvider == nil { + opts := awsAuth.Options{ + Logger: s.logger, + Properties: metadata.Properties, + Region: m.Region, + Endpoint: m.Endpoint, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, + } + // extra configs needed per component type + var provider awsAuth.Provider + provider, err = awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) + if err != nil { + return err + } + s.authProvider = provider } - // AWS sns,sqs,sts client. - s.snsClient = sns.New(sess) - s.sqsClient = sqs.New(sess) - s.stsClient = sts.New(sess) - s.opsTimeout = time.Duration(md.AssetsManagementTimeoutSeconds * float64(time.Second)) + s.opsTimeout = time.Duration(m.AssetsManagementTimeoutSeconds * float64(time.Second)) err = s.setAwsAccountIDIfNotProvided(ctx) if err != nil { @@ -181,9 +189,8 @@ func (s *snsSqs) setAwsAccountIDIfNotProvided(parentCtx context.Context) error { if len(s.metadata.AccountID) == awsAccountIDLength { return nil } - ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) - callerIDOutput, err := s.stsClient.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{}) + callerIDOutput, err := s.authProvider.SnsSqs().Sts.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{}) cancelFn() if err != nil { return fmt.Errorf("error fetching sts caller ID: %w", err) @@ -208,9 +215,8 @@ func (s *snsSqs) createTopic(parentCtx context.Context, topic string) (string, e attributes := map[string]*string{"FifoTopic": aws.String("true"), "ContentBasedDeduplication": aws.String("true")} snsCreateTopicInput.SetAttributes(attributes) } - ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) - createTopicResponse, err := s.snsClient.CreateTopicWithContext(ctx, snsCreateTopicInput) + createTopicResponse, err := s.authProvider.SnsSqs().Sns.CreateTopicWithContext(ctx, snsCreateTopicInput) cancelFn() if err != nil { return "", fmt.Errorf("error while creating an SNS topic: %w", err) @@ -222,7 +228,7 @@ func (s *snsSqs) createTopic(parentCtx context.Context, topic string) (string, e func (s *snsSqs) getTopicArn(parentCtx context.Context, topic string) (string, error) { ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) arn := s.buildARN("sns", topic) - getTopicOutput, err := s.snsClient.GetTopicAttributesWithContext(ctx, &sns.GetTopicAttributesInput{ + getTopicOutput, err := s.authProvider.SnsSqs().Sns.GetTopicAttributesWithContext(ctx, &sns.GetTopicAttributesInput{ TopicArn: &arn, }) cancelFn() @@ -288,15 +294,16 @@ func (s *snsSqs) createQueue(parentCtx context.Context, queueName string) (*sqsQ attributes := map[string]*string{"FifoQueue": aws.String("true"), "ContentBasedDeduplication": aws.String("true")} sqsCreateQueueInput.SetAttributes(attributes) } + ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - createQueueResponse, err := s.sqsClient.CreateQueueWithContext(ctx, sqsCreateQueueInput) + createQueueResponse, err := s.authProvider.SnsSqs().Sqs.CreateQueueWithContext(ctx, sqsCreateQueueInput) cancel() if err != nil { return nil, fmt.Errorf("error creaing an SQS queue: %w", err) } ctx, cancel = context.WithTimeout(parentCtx, s.opsTimeout) - queueAttributesResponse, err := s.sqsClient.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ + queueAttributesResponse, err := s.authProvider.SnsSqs().Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ AttributeNames: []*string{aws.String("QueueArn")}, QueueUrl: createQueueResponse.QueueUrl, }) @@ -313,7 +320,7 @@ func (s *snsSqs) createQueue(parentCtx context.Context, queueName string) (*sqsQ func (s *snsSqs) getQueueArn(parentCtx context.Context, queueName string) (*sqsQueueInfo, error) { ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - queueURLOutput, err := s.sqsClient.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{QueueName: aws.String(queueName), QueueOwnerAWSAccountId: aws.String(s.metadata.AccountID)}) + queueURLOutput, err := s.authProvider.SnsSqs().Sqs.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{QueueName: aws.String(queueName), QueueOwnerAWSAccountId: aws.String(s.metadata.AccountID)}) cancel() if err != nil { return nil, fmt.Errorf("error: %w while getting url of queue: %s", err, queueName) @@ -321,7 +328,7 @@ func (s *snsSqs) getQueueArn(parentCtx context.Context, queueName string) (*sqsQ url := queueURLOutput.QueueUrl ctx, cancel = context.WithTimeout(parentCtx, s.opsTimeout) - getQueueOutput, err := s.sqsClient.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{QueueUrl: url, AttributeNames: []*string{aws.String("QueueArn")}}) + getQueueOutput, err := s.authProvider.SnsSqs().Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{QueueUrl: url, AttributeNames: []*string{aws.String("QueueArn")}}) cancel() if err != nil { return nil, fmt.Errorf("error: %w while getting information for queue: %s, with url: %s", err, queueName, *url) @@ -382,7 +389,7 @@ func (s *snsSqs) getMessageGroupID(req *pubsub.PublishRequest) *string { func (s *snsSqs) createSnsSqsSubscription(parentCtx context.Context, queueArn, topicArn string) (string, error) { ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - subscribeOutput, err := s.snsClient.SubscribeWithContext(ctx, &sns.SubscribeInput{ + subscribeOutput, err := s.authProvider.SnsSqs().Sns.SubscribeWithContext(ctx, &sns.SubscribeInput{ Attributes: nil, Endpoint: aws.String(queueArn), // create SQS queue per subscription. Protocol: aws.String("sqs"), @@ -402,7 +409,7 @@ func (s *snsSqs) createSnsSqsSubscription(parentCtx context.Context, queueArn, t func (s *snsSqs) getSnsSqsSubscriptionArn(parentCtx context.Context, topicArn string) (string, error) { ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - listSubscriptionsOutput, err := s.snsClient.ListSubscriptionsByTopicWithContext(ctx, &sns.ListSubscriptionsByTopicInput{TopicArn: aws.String(topicArn)}) + listSubscriptionsOutput, err := s.authProvider.SnsSqs().Sns.ListSubscriptionsByTopicWithContext(ctx, &sns.ListSubscriptionsByTopicInput{TopicArn: aws.String(topicArn)}) cancel() if err != nil { return "", fmt.Errorf("error listing subsriptions for topic arn: %v: %w", topicArn, err) @@ -451,7 +458,7 @@ func (s *snsSqs) getOrCreateSnsSqsSubscription(ctx context.Context, queueArn, to func (s *snsSqs) acknowledgeMessage(parentCtx context.Context, queueURL string, receiptHandle *string) error { ctx, cancelFn := context.WithCancel(parentCtx) - _, err := s.sqsClient.DeleteMessageWithContext(ctx, &sqs.DeleteMessageInput{ + _, err := s.authProvider.SnsSqs().Sqs.DeleteMessageWithContext(ctx, &sqs.DeleteMessageInput{ QueueUrl: aws.String(queueURL), ReceiptHandle: receiptHandle, }) @@ -466,7 +473,7 @@ func (s *snsSqs) acknowledgeMessage(parentCtx context.Context, queueURL string, func (s *snsSqs) resetMessageVisibilityTimeout(parentCtx context.Context, queueURL string, receiptHandle *string) error { ctx, cancelFn := context.WithCancel(parentCtx) // reset the timeout to its initial value so that the remaining timeout would be overridden by the initial value for other consumer to attempt processing. - _, err := s.sqsClient.ChangeMessageVisibilityWithContext(ctx, &sqs.ChangeMessageVisibilityInput{ + _, err := s.authProvider.SnsSqs().Sqs.ChangeMessageVisibilityWithContext(ctx, &sqs.ChangeMessageVisibilityInput{ QueueUrl: aws.String(queueURL), ReceiptHandle: receiptHandle, VisibilityTimeout: aws.Int64(0), @@ -588,17 +595,23 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters WaitTimeSeconds: aws.Int64(s.metadata.MessageWaitTimeSeconds), } + // sem is a semaphore used to control the concurrencyLimit. + // It is set only when we are in parallel mode and limit is > 0. + var sem chan (struct{}) = nil + if (s.metadata.ConcurrencyMode == pubsub.Parallel) && s.metadata.ConcurrencyLimit > 0 { + sem = make(chan struct{}, s.metadata.ConcurrencyLimit) + } + for { // If the context is canceled, stop requesting messages if ctx.Err() != nil { break } - - // Internally, by default, aws go sdk performs 3 retires with exponential backoff to contact + // Internally, by default, aws go sdk performs 3 retries with exponential backoff to contact // sqs and try pull messages. Since we are iteratively short polling (based on the defined // s.metadata.messageWaitTimeSeconds) the sdk backoff is not effective as it gets reset per each polling // iteration. Therefore, a global backoff (to the internal backoff) is used (sqsPullExponentialBackoff). - messageResponse, err := s.sqsClient.ReceiveMessageWithContext(ctx, receiveMessageInput) + messageResponse, err := s.authProvider.SnsSqs().Sqs.ReceiveMessageWithContext(ctx, receiveMessageInput) if err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || ctx.Err() != nil { s.logger.Warn("context canceled; stopping consuming from queue arn: %v", queueInfo.arn) @@ -623,7 +636,6 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters } s.logger.Debugf("%v message(s) received on queue %s", len(messageResponse.Messages), queueInfo.arn) - var wg sync.WaitGroup for _, message := range messageResponse.Messages { if err := s.validateMessage(ctx, message, queueInfo, deadLettersQueueInfo); err != nil { s.logger.Errorf("message is not valid for further processing by the handler. error is: %v", err) @@ -631,25 +643,30 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters } f := func(message *sqs.Message) { - defer wg.Done() if err := s.callHandler(ctx, message, queueInfo); err != nil { s.logger.Errorf("error while handling received message. error is: %v", err) } } - wg.Add(1) switch s.metadata.ConcurrencyMode { case pubsub.Single: f(message) case pubsub.Parallel: - wg.Add(1) + // This is the back pressure mechanism. + // It will block until another goroutine frees a slot. + if sem != nil { + sem <- struct{}{} + } + go func(message *sqs.Message) { - defer wg.Done() + if sem != nil { + defer func() { <-sem }() + } + f(message) }(message) } } - wg.Wait() } } @@ -690,9 +707,8 @@ func (s *snsSqs) setDeadLettersQueueAttributes(parentCtx context.Context, queueI return wrappedErr } - ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) - _, derr = s.sqsClient.SetQueueAttributesWithContext(ctx, sqsSetQueueAttributesInput) + _, derr = s.authProvider.SnsSqs().Sqs.SetQueueAttributesWithContext(ctx, sqsSetQueueAttributesInput) cancelFn() if derr != nil { wrappedErr := fmt.Errorf("error updating queue attributes with dead-letter queue: %w", derr) @@ -712,7 +728,7 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(parentCtx context.Context, ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) // only permit SNS to send messages to SQS using the created subscription. - getQueueAttributesOutput, err := s.sqsClient.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ + getQueueAttributesOutput, err := s.authProvider.SnsSqs().Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ QueueUrl: &sqsQueueInfo.url, AttributeNames: []*string{aws.String(sqs.QueueAttributeNamePolicy)}, }) @@ -739,7 +755,7 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(parentCtx context.Context, } ctx, cancelFn = context.WithTimeout(parentCtx, s.opsTimeout) - _, err = s.sqsClient.SetQueueAttributesWithContext(ctx, &(sqs.SetQueueAttributesInput{ + _, err = s.authProvider.SnsSqs().Sqs.SetQueueAttributesWithContext(ctx, &(sqs.SetQueueAttributesInput{ Attributes: map[string]*string{ "Policy": aws.String(string(b)), }, @@ -852,7 +868,7 @@ func (s *snsSqs) Publish(ctx context.Context, req *pubsub.PublishRequest) error } // sns client has internal exponential backoffs. - _, err = s.snsClient.PublishWithContext(ctx, snsPublishInput) + _, err = s.authProvider.SnsSqs().Sns.PublishWithContext(ctx, snsPublishInput) if err != nil { wrappedErr := fmt.Errorf("error publishing to topic: %s with topic ARN %s: %w", req.Topic, topicArn, err) s.logger.Error(wrappedErr) @@ -870,6 +886,9 @@ func (s *snsSqs) Close() error { s.subscriptionManager.Close() } + if s.authProvider != nil { + return s.authProvider.Close() + } return nil } diff --git a/pubsub/aws/snssqs/snssqs_test.go b/pubsub/aws/snssqs/snssqs_test.go index 1c789b67be..82f58bcbff 100644 --- a/pubsub/aws/snssqs/snssqs_test.go +++ b/pubsub/aws/snssqs/snssqs_test.go @@ -38,7 +38,7 @@ func Test_parseTopicArn(t *testing.T) { } // Verify that all metadata ends up in the correct spot. -func Test_getSnsSqsMetatdata_AllConfiguration(t *testing.T) { +func Test_getSnsSqsMetadata_AllConfiguration(t *testing.T) { t.Parallel() r := require.New(t) l := logger.NewLogger("SnsSqs unit test") @@ -47,10 +47,11 @@ func Test_getSnsSqsMetatdata_AllConfiguration(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "consumer", "Endpoint": "endpoint", "concurrencyMode": string(pubsub.Single), + "concurrencyLimit": "42", "accessKey": "a", "secretKey": "s", "sessionToken": "t", @@ -68,6 +69,7 @@ func Test_getSnsSqsMetatdata_AllConfiguration(t *testing.T) { r.Equal("consumer", md.SqsQueueName) r.Equal("endpoint", md.Endpoint) r.Equal(pubsub.Single, md.ConcurrencyMode) + r.Equal(42, md.ConcurrencyLimit) r.Equal("a", md.AccessKey) r.Equal("s", md.SecretKey) r.Equal("t", md.SessionToken) @@ -80,7 +82,7 @@ func Test_getSnsSqsMetatdata_AllConfiguration(t *testing.T) { r.Equal(int64(6), md.MessageReceiveLimit) } -func Test_getSnsSqsMetatdata_defaults(t *testing.T) { +func Test_getSnsSqsMetadata_defaults(t *testing.T) { t.Parallel() r := require.New(t) l := logger.NewLogger("SnsSqs unit test") @@ -89,7 +91,7 @@ func Test_getSnsSqsMetatdata_defaults(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "c", "accessKey": "a", "secretKey": "s", @@ -105,6 +107,7 @@ func Test_getSnsSqsMetatdata_defaults(t *testing.T) { r.Equal("", md.SessionToken) r.Equal("r", md.Region) r.Equal(pubsub.Parallel, md.ConcurrencyMode) + r.Equal(0, md.ConcurrencyLimit) r.Equal(int64(10), md.MessageVisibilityTimeout) r.Equal(int64(10), md.MessageRetryLimit) r.Equal(int64(2), md.MessageWaitTimeSeconds) @@ -114,7 +117,7 @@ func Test_getSnsSqsMetatdata_defaults(t *testing.T) { r.False(md.DisableDeleteOnRetryLimit) } -func Test_getSnsSqsMetatdata_legacyaliases(t *testing.T) { +func Test_getSnsSqsMetadata_legacyaliases(t *testing.T) { t.Parallel() r := require.New(t) l := logger.NewLogger("SnsSqs unit test") @@ -123,7 +126,7 @@ func Test_getSnsSqsMetatdata_legacyaliases(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "consumer", "awsAccountID": "acctId", "awsSecret": "secret", @@ -151,13 +154,13 @@ func testMetadataParsingShouldFail(t *testing.T, metadata pubsub.Metadata, l log logger: l, } - md, err := ps.getSnsSqsMetatdata(metadata) + md, err := ps.getSnsSqsMetadata(metadata) r.Error(err) r.Nil(md) } -func Test_getSnsSqsMetatdata_invalidMetadataSetup(t *testing.T) { +func Test_getSnsSqsMetadata_invalidMetadataSetup(t *testing.T) { t.Parallel() fixtures := []testUnitFixture{ @@ -273,6 +276,20 @@ func Test_getSnsSqsMetatdata_invalidMetadataSetup(t *testing.T) { }}}, name: "invalid message concurrencyMode", }, + // invalid concurrencyLimit + { + metadata: pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + "consumerID": "consumer", + "Endpoint": "endpoint", + "AccessKey": "acctId", + "SecretKey": "secret", + "awsToken": "token", + "Region": "region", + "messageRetryLimit": "10", + "concurrencyLimit": "-1", + }}}, + name: "invalid message concurrencyLimit", + }, } l := logger.NewLogger("SnsSqs unit test") @@ -432,7 +449,7 @@ func Test_buildARN_DefaultPartition(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "c", "accessKey": "a", "secretKey": "s", @@ -455,7 +472,7 @@ func Test_buildARN_StandardPartition(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "c", "accessKey": "a", "secretKey": "s", @@ -478,7 +495,7 @@ func Test_buildARN_NonStandardPartition(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "c", "accessKey": "a", "secretKey": "s", diff --git a/pubsub/azure/eventhubs/eventhubs.go b/pubsub/azure/eventhubs/eventhubs.go index 65617239bf..ae6db85377 100644 --- a/pubsub/azure/eventhubs/eventhubs.go +++ b/pubsub/azure/eventhubs/eventhubs.go @@ -130,6 +130,10 @@ func (aeh *AzureEventHubs) Subscribe(ctx context.Context, req pubsub.SubscribeRe // Check if requireAllProperties is set and is truthy getAllProperties := utils.IsTruthy(req.Metadata["requireAllProperties"]) + if !getAllProperties { + getAllProperties = aeh.GetAllMessageProperties() + } + checkPointFrequencyPerPartition := commonutils.GetIntValFromString(req.Metadata["checkPointFrequencyPerPartition"], impl.DefaultCheckpointFrequencyPerPartition) pubsubHandler := aeh.GetPubSubHandlerFunc(topic, getAllProperties, handler) @@ -155,6 +159,9 @@ func (aeh *AzureEventHubs) BulkSubscribe(ctx context.Context, req pubsub.Subscri // Check if requireAllProperties is set and is truthy getAllProperties := utils.IsTruthy(req.Metadata["requireAllProperties"]) + if !getAllProperties { + getAllProperties = aeh.GetAllMessageProperties() + } checkPointFrequencyPerPartition := commonutils.GetIntValFromString(req.Metadata["checkPointFrequencyPerPartition"], impl.DefaultCheckpointFrequencyPerPartition) maxBulkSubCount := commonutils.GetIntValOrDefault(req.BulkSubscribeConfig.MaxMessagesCount, impl.DefaultMaxBulkSubCount) maxBulkSubAwaitDurationMs := commonutils.GetIntValOrDefault(req.BulkSubscribeConfig.MaxAwaitDurationMs, impl.DefaultMaxBulkSubAwaitDurationMs) diff --git a/pubsub/azure/eventhubs/metadata.yaml b/pubsub/azure/eventhubs/metadata.yaml index 768d472252..0f56fc0932 100644 --- a/pubsub/azure/eventhubs/metadata.yaml +++ b/pubsub/azure/eventhubs/metadata.yaml @@ -35,6 +35,13 @@ builtinAuthenticationProfiles: example: "false" description: | Allow management of the Event Hub namespace and storage account. + - name: enableInOrderMessageDelivery + type: bool + required: false + default: "false" + example: "false" + description: | + Enable in order processing of messages within a partition. # The following four properties are needed only if enableEntityManagement is set to true - name: resourceGroupName @@ -103,3 +110,12 @@ metadata: description: | The name of the Event Hubs Consumer Group to listen on. example: '"group1"' + - name: getAllMessageProperties + required: false + default: "false" + example: "false" + binding: + input: true + output: false + description: | + When set to true, will retrieve all message properties and include them in the returned event metadata diff --git a/pubsub/kafka/metadata.yaml b/pubsub/kafka/metadata.yaml index b6536c2b43..566c9a58c3 100644 --- a/pubsub/kafka/metadata.yaml +++ b/pubsub/kafka/metadata.yaml @@ -8,6 +8,67 @@ title: "Apache Kafka" urls: - title: Reference url: https://docs.dapr.io/reference/components-reference/supported-pubsub/setup-apache-kafka/ +# This auth profile has duplicate fields intentionally as we maintain backwards compatibility, +# but also move Kafka to utilize the noramlized AWS fields in the builtin auth profiles. +# TODO: rm the duplicate aws prefixed fields in Dapr 1.17. +builtinAuthenticationProfiles: + - name: "aws" + metadata: + - name: authType + type: string + required: true + description: | + Authentication type. + This must be set to "awsiam" for this authentication profile. + example: '"awsiam"' + allowedValues: + - "awsiam" + - name: awsAccessKey + type: string + required: false + description: | + This maintains backwards compatibility with existing fields. + It will be deprecated as of Dapr 1.17. Use 'accessKey' instead. + If both fields are set, then 'accessKey' value will be used. + AWS access key associated with an IAM account. + example: '"AKIAIOSFODNN7EXAMPLE"' + - name: awsSecretKey + type: string + required: false + sensitive: true + description: | + This maintains backwards compatibility with existing fields. + It will be deprecated as of Dapr 1.17. Use 'secretKey' instead. + If both fields are set, then 'secretKey' value will be used. + The secret key associated with the access key. + example: '"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"' + - name: awsSessionToken + type: string + sensitive: true + description: | + This maintains backwards compatibility with existing fields. + It will be deprecated as of Dapr 1.17. Use 'sessionToken' instead. + If both fields are set, then 'sessionToken' value will be used. + AWS session token to use. A session token is only required if you are using temporary security credentials. + example: '"TOKEN"' + - name: awsIamRoleArn + type: string + required: false + description: | + This maintains backwards compatibility with existing fields. + It will be deprecated as of Dapr 1.17. Use 'assumeRoleArn' instead. + If both fields are set, then 'assumeRoleArn' value will be used. + IAM role that has access to MSK. This is another option to authenticate with MSK aside from the AWS Credentials. + example: '"arn:aws:iam::123456789:role/mskRole"' + - name: awsStsSessionName + type: string + description: | + This maintains backwards compatibility with existing fields. + It will be deprecated as of Dapr 1.17. Use 'sessionName' instead. + If both fields are set, then 'sessionName' value will be used. + Represents the session name for assuming a role. + example: '"MyAppSession"' + default: '"DaprDefaultSession"' authenticationProfiles: - title: "OIDC Authentication" description: | @@ -133,55 +194,6 @@ authenticationProfiles: example: '"none"' allowedValues: - "none" - - title: "AWS IAM" - description: "Authenticate using AWS IAM credentials or role for AWS MSK" - metadata: - - name: authType - type: string - required: true - description: | - Authentication type. - This must be set to "awsiam" for this authentication profile. - example: '"awsiam"' - allowedValues: - - "awsiam" - - name: awsRegion - type: string - required: true - description: | - The AWS Region where the MSK Kafka broker is deployed to. - example: '"us-east-1"' - - name: awsAccessKey - type: string - required: true - description: | - AWS access key associated with an IAM account. - example: '"AKIAIOSFODNN7EXAMPLE"' - - name: awsSecretKey - type: string - required: true - sensitive: true - description: | - The secret key associated with the access key. - example: '"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"' - - name: awsSessionToken - type: string - sensitive: true - description: | - AWS session token to use. A session token is only required if you are using temporary security credentials. - example: '"TOKEN"' - - name: awsIamRoleArn - type: string - required: true - description: | - IAM role that has access to MSK. This is another option to authenticate with MSK aside from the AWS Credentials. - example: '"arn:aws:iam::123456789:role/mskRole"' - - name: awsStsSessionName - type: string - description: | - Represents the session name for assuming a role. - example: '"MyAppSession"' - default: '"MSKSASLDefaultSession"' metadata: - name: brokers type: string diff --git a/pubsub/pulsar/metadata.go b/pubsub/pulsar/metadata.go index 62b3b06bbc..ba5067796b 100644 --- a/pubsub/pulsar/metadata.go +++ b/pubsub/pulsar/metadata.go @@ -20,25 +20,25 @@ import ( ) type pulsarMetadata struct { - Host string `mapstructure:"host"` - ConsumerID string `mapstructure:"consumerID"` - EnableTLS bool `mapstructure:"enableTLS"` - DisableBatching bool `mapstructure:"disableBatching"` - BatchingMaxPublishDelay time.Duration `mapstructure:"batchingMaxPublishDelay"` - BatchingMaxSize uint `mapstructure:"batchingMaxSize"` - BatchingMaxMessages uint `mapstructure:"batchingMaxMessages"` - Tenant string `mapstructure:"tenant"` - Namespace string `mapstructure:"namespace"` - Persistent bool `mapstructure:"persistent"` - RedeliveryDelay time.Duration `mapstructure:"redeliveryDelay"` - internalTopicSchemas map[string]schemaMetadata `mapstructure:"-"` - PublicKey string `mapstructure:"publicKey"` - PrivateKey string `mapstructure:"privateKey"` - Keys string `mapstructure:"keys"` - MaxConcurrentHandlers uint `mapstructure:"maxConcurrentHandlers"` - ReceiverQueueSize int `mapstructure:"receiverQueueSize"` - - Token string `mapstructure:"token"` + Host string `mapstructure:"host"` + ConsumerID string `mapstructure:"consumerID"` + EnableTLS bool `mapstructure:"enableTLS"` + DisableBatching bool `mapstructure:"disableBatching"` + BatchingMaxPublishDelay time.Duration `mapstructure:"batchingMaxPublishDelay"` + BatchingMaxSize uint `mapstructure:"batchingMaxSize"` + BatchingMaxMessages uint `mapstructure:"batchingMaxMessages"` + Tenant string `mapstructure:"tenant"` + Namespace string `mapstructure:"namespace"` + Persistent bool `mapstructure:"persistent"` + RedeliveryDelay time.Duration `mapstructure:"redeliveryDelay"` + internalTopicSchemas map[string]schemaMetadata `mapstructure:"-"` + PublicKey string `mapstructure:"publicKey"` + PrivateKey string `mapstructure:"privateKey"` + Keys string `mapstructure:"keys"` + MaxConcurrentHandlers uint `mapstructure:"maxConcurrentHandlers"` + ReceiverQueueSize int `mapstructure:"receiverQueueSize"` + SubscriptionType string `mapstructure:"subscribeType"` + Token string `mapstructure:"token"` oauth2.ClientCredentialsMetadata `mapstructure:",squash"` } diff --git a/pubsub/pulsar/metadata.yaml b/pubsub/pulsar/metadata.yaml index 7cc216cf12..63421fe3c1 100644 --- a/pubsub/pulsar/metadata.yaml +++ b/pubsub/pulsar/metadata.yaml @@ -183,4 +183,13 @@ metadata: Sets the size of the consumer receive queue. Controls how many messages can be accumulated by the consumer before it is explicitly called to read messages by Dapr. default: '"1000"' - example: '"1000"' \ No newline at end of file + example: '"1000"' + - name: subscribeType + type: string + description: | + Pulsar supports four subscription types:"shared", "exclusive", "failover", "key_shared". + default: '"shared"' + example: '"exclusive"' + url: + title: "Pulsar Subscription Types" + url: "https://pulsar.apache.org/docs/3.0.x/concepts-messaging/#subscription-types" \ No newline at end of file diff --git a/pubsub/pulsar/pulsar.go b/pubsub/pulsar/pulsar.go index 7822d63f5e..39074509ce 100644 --- a/pubsub/pulsar/pulsar.go +++ b/pubsub/pulsar/pulsar.go @@ -138,6 +138,12 @@ func parsePulsarMetadata(meta pubsub.Metadata) (*pulsarMetadata, error) { return nil, errors.New("pulsar error: missing pulsar host") } + var err error + m.SubscriptionType, err = parseSubscriptionType(meta.Properties[subscribeTypeKey]) + if err != nil { + return nil, errors.New("invalid subscription type. Accepted values are `exclusive`, `shared`, `failover` and `key_shared`") + } + for k, v := range meta.Properties { switch { case strings.HasSuffix(k, topicJSONSchemaIdentifier): @@ -170,10 +176,8 @@ func (p *Pulsar) Init(ctx context.Context, metadata pubsub.Metadata) error { return err } pulsarURL := m.Host - if !strings.HasPrefix(m.Host, "http://") && - !strings.HasPrefix(m.Host, "https://") { - pulsarURL = fmt.Sprintf("%s%s", pulsarPrefix, m.Host) - } + + pulsarURL = sanitiseURL(pulsarURL) options := pulsar.ClientOptions{ URL: pulsarURL, OperationTimeout: 30 * time.Second, @@ -226,6 +230,23 @@ func (p *Pulsar) Init(ctx context.Context, metadata pubsub.Metadata) error { return nil } +func sanitiseURL(pulsarURL string) string { + prefixes := []string{"pulsar+ssl://", "pulsar://", "http://", "https://"} + + hasPrefix := false + for _, prefix := range prefixes { + if strings.HasPrefix(pulsarURL, prefix) { + hasPrefix = true + break + } + } + + if !hasPrefix { + pulsarURL = fmt.Sprintf("%s%s", pulsarPrefix, pulsarURL) + } + return pulsarURL +} + func (p *Pulsar) useProducerEncryption() bool { return p.metadata.PublicKey != "" && p.metadata.Keys != "" } @@ -370,11 +391,22 @@ func parsePublishMetadata(req *pubsub.PublishRequest, schema schemaMetadata) ( return msg, nil } -// default: shared -func getSubscribeType(metadata map[string]string) pulsar.SubscriptionType { +func parseSubscriptionType(in string) (string, error) { + subsType := strings.ToLower(in) + switch subsType { + case subscribeTypeExclusive, subscribeTypeFailover, subscribeTypeShared, subscribeTypeKeyShared: + return subsType, nil + case "": + return subscribeTypeShared, nil + default: + return "", fmt.Errorf("invalid subscription type: %s", subsType) + } +} + +// getSubscribeType doesn't do extra validations, because they were done in parseSubscriptionType. +func getSubscribeType(subsTypeStr string) pulsar.SubscriptionType { var subsType pulsar.SubscriptionType - subsTypeStr := strings.ToLower(metadata[subscribeTypeKey]) switch subsTypeStr { case subscribeTypeExclusive: subsType = pulsar.Exclusive @@ -384,8 +416,6 @@ func getSubscribeType(metadata map[string]string) pulsar.SubscriptionType { subsType = pulsar.Shared case subscribeTypeKeyShared: subsType = pulsar.KeyShared - default: - subsType = pulsar.Shared } return subsType @@ -400,15 +430,27 @@ func (p *Pulsar) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, han topic := p.formatTopic(req.Topic) + subscribeType := p.metadata.SubscriptionType + if s, exists := req.Metadata[subscribeTypeKey]; exists { + subscribeType = s + } + options := pulsar.ConsumerOptions{ Topic: topic, SubscriptionName: p.metadata.ConsumerID, - Type: getSubscribeType(req.Metadata), + Type: getSubscribeType(subscribeType), MessageChannel: channel, NackRedeliveryDelay: p.metadata.RedeliveryDelay, ReceiverQueueSize: p.metadata.ReceiverQueueSize, } + // Handle KeySharedPolicy for key_shared subscription type + if options.Type == pulsar.KeyShared { + options.KeySharedPolicy = &pulsar.KeySharedPolicy{ + Mode: pulsar.KeySharedPolicyModeAutoSplit, + } + } + if p.useConsumerEncryption() { var reader crypto.KeyReader if isValidPEM(p.metadata.PublicKey) { @@ -430,6 +472,7 @@ func (p *Pulsar) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, han p.logger.Debugf("Could not subscribe to %s, full topic name in pulsar is %s", req.Topic, topic) return err } + p.logger.Debugf("Subscribed to '%s'(%s) with type '%s'", req.Topic, topic, subscribeType) p.wg.Add(2) listenCtx, cancel := context.WithCancel(ctx) diff --git a/pubsub/pulsar/pulsar_test.go b/pubsub/pulsar/pulsar_test.go index 82b8157cd1..da1c62247e 100644 --- a/pubsub/pulsar/pulsar_test.go +++ b/pubsub/pulsar/pulsar_test.go @@ -48,6 +48,73 @@ func TestParsePulsarMetadata(t *testing.T) { assert.Equal(t, uint(200), meta.BatchingMaxMessages) assert.Equal(t, uint(333), meta.MaxConcurrentHandlers) assert.Empty(t, meta.internalTopicSchemas) + assert.Equal(t, "shared", meta.SubscriptionType) +} + +func TestParsePulsarMetadataSubscriptionType(t *testing.T) { + tt := []struct { + name string + subscribeType string + expected string + err bool + }{ + { + name: "test valid subscribe type - key_shared", + subscribeType: "key_shared", + expected: "key_shared", + err: false, + }, + { + name: "test valid subscribe type - shared", + subscribeType: "shared", + expected: "shared", + err: false, + }, + { + name: "test valid subscribe type - failover", + subscribeType: "failover", + expected: "failover", + err: false, + }, + { + name: "test valid subscribe type - exclusive", + subscribeType: "exclusive", + expected: "exclusive", + err: false, + }, + { + name: "test valid subscribe type - empty", + subscribeType: "", + expected: "shared", + err: false, + }, + { + name: "test invalid subscribe type", + subscribeType: "invalid", + err: true, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + m := pubsub.Metadata{} + + m.Properties = map[string]string{ + "host": "a", + "subscribeType": tc.subscribeType, + } + meta, err := parsePulsarMetadata(m) + + if tc.err { + require.Error(t, err) + assert.Nil(t, meta) + return + } + + require.NoError(t, err) + assert.Equal(t, tc.expected, meta.SubscriptionType) + }) + } } func TestParsePulsarSchemaMetadata(t *testing.T) { @@ -328,3 +395,27 @@ func TestEncryptionKeys(t *testing.T) { assert.False(t, r) }) } + +func TestSanitiseURL(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"With pulsar+ssl prefix", "pulsar+ssl://localhost:6650", "pulsar+ssl://localhost:6650"}, + {"With pulsar prefix", "pulsar://localhost:6650", "pulsar://localhost:6650"}, + {"With http prefix", "http://localhost:6650", "http://localhost:6650"}, + {"With https prefix", "https://localhost:6650", "https://localhost:6650"}, + {"Without prefix", "localhost:6650", "pulsar://localhost:6650"}, + {"Empty string", "", "pulsar://"}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + actual := sanitiseURL(test.input) + if actual != test.expected { + t.Errorf("sanitiseURL(%q) = %q; want %q", test.input, actual, test.expected) + } + }) + } +} diff --git a/secretstores/aws/parameterstore/parameterstore.go b/secretstores/aws/parameterstore/parameterstore.go index 1f82031ba8..038399b30c 100644 --- a/secretstores/aws/parameterstore/parameterstore.go +++ b/secretstores/aws/parameterstore/parameterstore.go @@ -20,7 +20,6 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ssm" - "github.com/aws/aws-sdk-go/service/ssm/ssmiface" awsAuth "github.com/dapr/components-contrib/common/authentication/aws" "github.com/dapr/components-contrib/metadata" @@ -53,23 +52,33 @@ type ParameterStoreMetaData struct { } type ssmSecretStore struct { - client ssmiface.SSMAPI - prefix string - logger logger.Logger + authProvider awsAuth.Provider + prefix string + logger logger.Logger } // Init creates an AWS secret manager client. func (s *ssmSecretStore) Init(ctx context.Context, metadata secretstores.Metadata) error { - meta, err := s.getSecretManagerMetadata(metadata) + m, err := s.getSecretManagerMetadata(metadata) if err != nil { return err } - s.client, err = s.getClient(meta) + opts := awsAuth.Options{ + Logger: s.logger, + Properties: metadata.Properties, + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: "", + } + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } - s.prefix = meta.Prefix + s.authProvider = provider + s.prefix = m.Prefix return nil } @@ -84,7 +93,7 @@ func (s *ssmSecretStore) GetSecret(ctx context.Context, req secretstores.GetSecr name = fmt.Sprintf("%s:%s", req.Name, versionID) } - output, err := s.client.GetParameterWithContext(ctx, &ssm.GetParameterInput{ + output, err := s.authProvider.ParameterStore().Store.GetParameterWithContext(ctx, &ssm.GetParameterInput{ Name: ptr.Of(s.prefix + name), WithDecryption: ptr.Of(true), }) @@ -124,7 +133,7 @@ func (s *ssmSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bul } for search { - output, err := s.client.DescribeParametersWithContext(ctx, &ssm.DescribeParametersInput{ + output, err := s.authProvider.ParameterStore().Store.DescribeParametersWithContext(ctx, &ssm.DescribeParametersInput{ MaxResults: nil, NextToken: nextToken, ParameterFilters: filters, @@ -134,7 +143,7 @@ func (s *ssmSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bul } for _, entry := range output.Parameters { - params, err := s.client.GetParameterWithContext(ctx, &ssm.GetParameterInput{ + params, err := s.authProvider.ParameterStore().Store.GetParameterWithContext(ctx, &ssm.GetParameterInput{ Name: entry.Name, WithDecryption: aws.Bool(true), }) @@ -155,15 +164,6 @@ func (s *ssmSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bul return resp, nil } -func (s *ssmSecretStore) getClient(metadata *ParameterStoreMetaData) (*ssm.SSM, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, "") - if err != nil { - return nil, err - } - - return ssm.New(sess), nil -} - func (s *ssmSecretStore) getSecretManagerMetadata(spec secretstores.Metadata) (*ParameterStoreMetaData, error) { meta := ParameterStoreMetaData{} err := kitmd.DecodeMetadata(spec.Properties, &meta) @@ -182,5 +182,8 @@ func (s *ssmSecretStore) GetComponentMetadata() (metadataInfo metadata.MetadataM } func (s *ssmSecretStore) Close() error { + if s.authProvider != nil { + return s.authProvider.Close() + } return nil } diff --git a/secretstores/aws/parameterstore/parameterstore_test.go b/secretstores/aws/parameterstore/parameterstore_test.go index 8d9bcf6065..04c7a6995e 100644 --- a/secretstores/aws/parameterstore/parameterstore_test.go +++ b/secretstores/aws/parameterstore/parameterstore_test.go @@ -22,9 +22,11 @@ import ( "testing" "github.com/aws/aws-sdk-go/aws" + + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/ssm" - "github.com/aws/aws-sdk-go/service/ssm/ssmiface" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -34,20 +36,6 @@ import ( const secretValue = "secret" -type mockedSSM struct { - GetParameterFn func(context.Context, *ssm.GetParameterInput, ...request.Option) (*ssm.GetParameterOutput, error) - DescribeParametersFn func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) - ssmiface.SSMAPI -} - -func (m *mockedSSM) GetParameterWithContext(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - return m.GetParameterFn(ctx, input, option...) -} - -func (m *mockedSSM) DescribeParametersWithContext(ctx context.Context, input *ssm.DescribeParametersInput, option ...request.Option) (*ssm.DescribeParametersOutput, error) { - return m.DescribeParametersFn(ctx, input, option...) -} - func TestInit(t *testing.T) { m := secretstores.Metadata{} s := NewParameterStore(logger.NewLogger("test")) @@ -68,21 +56,32 @@ func TestInit(t *testing.T) { func TestGetSecret(t *testing.T) { t.Run("successfully retrieve secret", func(t *testing.T) { t.Run("with valid path", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - secret := secretValue - - return &ssm.GetParameterOutput{ - Parameter: &ssm.Parameter{ - Name: input.Name, - Value: &secret, - }, - }, nil - }, + mockSSM := &awsAuth.MockParameterStore{ + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + secret := secretValue + return &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: input.Name, + Value: &secret, + }, + }, nil }, } + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + + s := ssmSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/dev/secret", Metadata: map[string]string{}, @@ -93,25 +92,36 @@ func TestGetSecret(t *testing.T) { }) t.Run("with version id", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - secret := secretValue - keys := strings.Split(*input.Name, ":") - assert.NotNil(t, keys) - assert.Len(t, keys, 2) - assert.Equalf(t, "1", keys[1], "Version IDs are same") - - return &ssm.GetParameterOutput{ - Parameter: &ssm.Parameter{ - Name: &keys[0], - Value: &secret, - }, - }, nil - }, + mockSSM := &awsAuth.MockParameterStore{ + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + secret := secretValue + keys := strings.Split(*input.Name, ":") + assert.NotNil(t, keys) + assert.Len(t, keys, 2) + assert.Equalf(t, "1", keys[1], "Version IDs are same") + + return &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: &keys[0], + Value: &secret, + }, + }, nil }, } + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := ssmSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/dev/secret", Metadata: map[string]string{ @@ -124,21 +134,33 @@ func TestGetSecret(t *testing.T) { }) t.Run("with prefix", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - assert.Equal(t, "/prefix/aws/dev/secret", *input.Name) - secret := secretValue - - return &ssm.GetParameterOutput{ - Parameter: &ssm.Parameter{ - Name: input.Name, - Value: &secret, - }, - }, nil - }, + mockSSM := &awsAuth.MockParameterStore{ + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + assert.Equal(t, "/prefix/aws/dev/secret", *input.Name) + secret := secretValue + + return &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: input.Name, + Value: &secret, + }, + }, nil }, - prefix: "/prefix", + } + + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + + s := ssmSecretStore{ + authProvider: mockAuthProvider, + prefix: "/prefix", } req := secretstores.GetSecretRequest{ @@ -152,13 +174,27 @@ func TestGetSecret(t *testing.T) { }) t.Run("unsuccessfully retrieve secret", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - return nil, errors.New("failed due to any reason") - }, + mockSSM := &awsAuth.MockParameterStore{ + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + return nil, errors.New("failed due to any reason") }, } + + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + + s := ssmSecretStore{ + authProvider: mockAuthProvider, + prefix: "/prefix", + } + req := secretstores.GetSecretRequest{ Name: "/aws/dev/secret", Metadata: map[string]string{}, @@ -170,31 +206,42 @@ func TestGetSecret(t *testing.T) { func TestGetBulkSecrets(t *testing.T) { t.Run("successfully retrieve bulk secrets", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { - return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ - { - Name: aws.String("/aws/dev/secret1"), - }, - { - Name: aws.String("/aws/dev/secret2"), - }, - }}, nil - }, - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - secret := fmt.Sprintf("%s-%s", *input.Name, secretValue) + mockSSM := &awsAuth.MockParameterStore{ + DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { + return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ + { + Name: aws.String("/aws/dev/secret1"), + }, + { + Name: aws.String("/aws/dev/secret2"), + }, + }}, nil + }, + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + secret := fmt.Sprintf("%s-%s", *input.Name, secretValue) - return &ssm.GetParameterOutput{ - Parameter: &ssm.Parameter{ - Name: input.Name, - Value: &secret, - }, - }, nil - }, + return &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: input.Name, + Value: &secret, + }, + }, nil }, } + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := ssmSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.BulkGetSecretRequest{ Metadata: map[string]string{}, } @@ -205,30 +252,41 @@ func TestGetBulkSecrets(t *testing.T) { }) t.Run("successfully retrieve bulk secrets with prefix", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { - return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ - { - Name: aws.String("/prefix/aws/dev/secret1"), - }, - { - Name: aws.String("/prefix/aws/dev/secret2"), - }, - }}, nil - }, - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - secret := fmt.Sprintf("%s-%s", *input.Name, secretValue) + mockSSM := &awsAuth.MockParameterStore{ + DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { + return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ + { + Name: aws.String("/prefix/aws/dev/secret1"), + }, + { + Name: aws.String("/prefix/aws/dev/secret2"), + }, + }}, nil + }, + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + secret := fmt.Sprintf("%s-%s", *input.Name, secretValue) - return &ssm.GetParameterOutput{ - Parameter: &ssm.Parameter{ - Name: input.Name, - Value: &secret, - }, - }, nil - }, + return &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: input.Name, + Value: &secret, + }, + }, nil }, - prefix: "/prefix", + } + + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := ssmSecretStore{ + authProvider: mockAuthProvider, + prefix: "/prefix", } req := secretstores.BulkGetSecretRequest{ @@ -241,23 +299,35 @@ func TestGetBulkSecrets(t *testing.T) { }) t.Run("unsuccessfully retrieve bulk secrets on get parameter", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { - return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ - { - Name: aws.String("/aws/dev/secret1"), - }, - { - Name: aws.String("/aws/dev/secret2"), - }, - }}, nil - }, - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - return nil, errors.New("failed due to any reason") - }, + mockSSM := &awsAuth.MockParameterStore{ + DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { + return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ + { + Name: aws.String("/aws/dev/secret1"), + }, + { + Name: aws.String("/aws/dev/secret2"), + }, + }}, nil }, + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + return nil, errors.New("failed due to any reason") + }, + } + + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := ssmSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.BulkGetSecretRequest{ Metadata: map[string]string{}, } @@ -266,13 +336,25 @@ func TestGetBulkSecrets(t *testing.T) { }) t.Run("unsuccessfully retrieve bulk secrets on describe parameter", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { - return nil, errors.New("failed due to any reason") - }, + mockSSM := &awsAuth.MockParameterStore{ + DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { + return nil, errors.New("failed due to any reason") }, } + + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := ssmSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.BulkGetSecretRequest{ Metadata: map[string]string{}, } diff --git a/secretstores/aws/secretmanager/secretmanager.go b/secretstores/aws/secretmanager/secretmanager.go index 54ed329d35..979739be5b 100644 --- a/secretstores/aws/secretmanager/secretmanager.go +++ b/secretstores/aws/secretmanager/secretmanager.go @@ -20,7 +20,6 @@ import ( "reflect" "github.com/aws/aws-sdk-go/service/secretsmanager" - "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" awsAuth "github.com/dapr/components-contrib/common/authentication/aws" "github.com/dapr/components-contrib/metadata" @@ -49,8 +48,8 @@ type SecretManagerMetaData struct { } type smSecretStore struct { - client secretsmanageriface.SecretsManagerAPI - logger logger.Logger + authProvider awsAuth.Provider + logger logger.Logger } // Init creates an AWS secret manager client. @@ -60,11 +59,20 @@ func (s *smSecretStore) Init(ctx context.Context, metadata secretstores.Metadata return err } - s.client, err = s.getClient(meta) + opts := awsAuth.Options{ + Logger: s.logger, + Region: meta.Region, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: meta.SessionToken, + Endpoint: meta.Endpoint, + } + + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } - + s.authProvider = provider return nil } @@ -78,8 +86,7 @@ func (s *smSecretStore) GetSecret(ctx context.Context, req secretstores.GetSecre if value, ok := req.Metadata[VersionStage]; ok { versionStage = &value } - - output, err := s.client.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ + output, err := s.authProvider.SecretManager().Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ SecretId: &req.Name, VersionId: versionID, VersionStage: versionStage, @@ -108,7 +115,7 @@ func (s *smSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bulk var nextToken *string = nil for search { - output, err := s.client.ListSecretsWithContext(ctx, &secretsmanager.ListSecretsInput{ + output, err := s.authProvider.SecretManager().Manager.ListSecretsWithContext(ctx, &secretsmanager.ListSecretsInput{ MaxResults: nil, NextToken: nextToken, }) @@ -117,7 +124,7 @@ func (s *smSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bulk } for _, entry := range output.SecretList { - secrets, err := s.client.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ + secrets, err := s.authProvider.SecretManager().Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ SecretId: entry.Name, }) if err != nil { @@ -136,15 +143,6 @@ func (s *smSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bulk return resp, nil } -func (s *smSecretStore) getClient(metadata *SecretManagerMetaData) (*secretsmanager.SecretsManager, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - - return secretsmanager.New(sess), nil -} - func (s *smSecretStore) getSecretManagerMetadata(spec secretstores.Metadata) (*SecretManagerMetaData, error) { b, err := json.Marshal(spec.Properties) if err != nil { @@ -172,5 +170,8 @@ func (s *smSecretStore) GetComponentMetadata() (metadataInfo metadata.MetadataMa } func (s *smSecretStore) Close() error { + if s.authProvider != nil { + return s.authProvider.Close() + } return nil } diff --git a/secretstores/aws/secretmanager/secretmanager_test.go b/secretstores/aws/secretmanager/secretmanager_test.go index 85918237a3..7fbd8493af 100644 --- a/secretstores/aws/secretmanager/secretmanager_test.go +++ b/secretstores/aws/secretmanager/secretmanager_test.go @@ -21,25 +21,17 @@ import ( "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/secretsmanager" - "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" + "github.com/dapr/components-contrib/secretstores" "github.com/dapr/kit/logger" ) const secretValue = "secret" -type mockedSM struct { - GetSecretValueFn func(context.Context, *secretsmanager.GetSecretValueInput, ...request.Option) (*secretsmanager.GetSecretValueOutput, error) - secretsmanageriface.SecretsManagerAPI -} - -func (m *mockedSM) GetSecretValueWithContext(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { - return m.GetSecretValueFn(ctx, input, option...) -} - func TestInit(t *testing.T) { m := secretstores.Metadata{} s := NewSecretManager(logger.NewLogger("test")) @@ -60,21 +52,32 @@ func TestInit(t *testing.T) { func TestGetSecret(t *testing.T) { t.Run("successfully retrieve secret", func(t *testing.T) { t.Run("without version id and version stage", func(t *testing.T) { - s := smSecretStore{ - client: &mockedSM{ - GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { - assert.Nil(t, input.VersionId) - assert.Nil(t, input.VersionStage) - secret := secretValue - - return &secretsmanager.GetSecretValueOutput{ - Name: input.SecretId, - SecretString: &secret, - }, nil - }, + mockSSM := &awsAuth.MockSecretManager{ + GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { + assert.Nil(t, input.VersionId) + assert.Nil(t, input.VersionStage) + secret := secretValue + + return &secretsmanager.GetSecretValueOutput{ + Name: input.SecretId, + SecretString: &secret, + }, nil }, } + secret := awsAuth.SecretManagerClients{ + Manager: mockSSM, + } + + mockedClients := awsAuth.Clients{ + Secret: &secret, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := smSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/secret/testing", Metadata: map[string]string{}, @@ -85,20 +88,32 @@ func TestGetSecret(t *testing.T) { }) t.Run("with version id", func(t *testing.T) { - s := smSecretStore{ - client: &mockedSM{ - GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { - assert.NotNil(t, input.VersionId) - secret := secretValue - - return &secretsmanager.GetSecretValueOutput{ - Name: input.SecretId, - SecretString: &secret, - }, nil - }, + mockSSM := &awsAuth.MockSecretManager{ + GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { + assert.NotNil(t, input.VersionId) + secret := secretValue + + return &secretsmanager.GetSecretValueOutput{ + Name: input.SecretId, + SecretString: &secret, + }, nil }, } + secret := awsAuth.SecretManagerClients{ + Manager: mockSSM, + } + + mockedClients := awsAuth.Clients{ + Secret: &secret, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := smSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/secret/testing", Metadata: map[string]string{ @@ -111,20 +126,32 @@ func TestGetSecret(t *testing.T) { }) t.Run("with version stage", func(t *testing.T) { - s := smSecretStore{ - client: &mockedSM{ - GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { - assert.NotNil(t, input.VersionStage) - secret := secretValue - - return &secretsmanager.GetSecretValueOutput{ - Name: input.SecretId, - SecretString: &secret, - }, nil - }, + mockSSM := &awsAuth.MockSecretManager{ + GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { + assert.NotNil(t, input.VersionStage) + secret := secretValue + + return &secretsmanager.GetSecretValueOutput{ + Name: input.SecretId, + SecretString: &secret, + }, nil }, } + secret := awsAuth.SecretManagerClients{ + Manager: mockSSM, + } + + mockedClients := awsAuth.Clients{ + Secret: &secret, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := smSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/secret/testing", Metadata: map[string]string{ @@ -138,13 +165,26 @@ func TestGetSecret(t *testing.T) { }) t.Run("unsuccessfully retrieve secret", func(t *testing.T) { - s := smSecretStore{ - client: &mockedSM{ - GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { - return nil, errors.New("failed due to any reason") - }, + mockSSM := &awsAuth.MockSecretManager{ + GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { + return nil, errors.New("failed due to any reason") }, } + + secret := awsAuth.SecretManagerClients{ + Manager: mockSSM, + } + + mockedClients := awsAuth.Clients{ + Secret: &secret, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := smSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/secret/testing", Metadata: map[string]string{}, diff --git a/state/aws/dynamodb/dynamodb.go b/state/aws/dynamodb/dynamodb.go index 503d7082c7..1ce4fd6de9 100644 --- a/state/aws/dynamodb/dynamodb.go +++ b/state/aws/dynamodb/dynamodb.go @@ -25,7 +25,6 @@ import ( "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" jsoniterator "github.com/json-iterator/go" awsAuth "github.com/dapr/components-contrib/common/authentication/aws" @@ -41,7 +40,8 @@ import ( type StateStore struct { state.BulkStore - client dynamodbiface.DynamoDBAPI + authProvider awsAuth.Provider + logger logger.Logger table string ttlAttributeName string partitionKey string @@ -53,6 +53,7 @@ type dynamoDBMetadata struct { SecretKey string `json:"secretKey" mapstructure:"secretKey" mdignore:"true"` SessionToken string `json:"sessionToken" mapstructure:"sessionToken" mdignore:"true"` + // TODO: rm the alias in Dapr 1.17 Region string `json:"region" mapstructure:"region" mapstructurealiases:"awsRegion" mdignore:"true"` Endpoint string `json:"endpoint"` Table string `json:"table"` @@ -66,9 +67,10 @@ const ( ) // NewDynamoDBStateStore returns a new dynamoDB state store. -func NewDynamoDBStateStore(_ logger.Logger) state.Store { +func NewDynamoDBStateStore(logger logger.Logger) state.Store { s := &StateStore{ partitionKey: defaultPartitionKeyName, + logger: logger, } s.BulkStore = state.NewDefaultBulkStore(s) return s @@ -80,14 +82,24 @@ func (d *StateStore) Init(ctx context.Context, metadata state.Metadata) error { if err != nil { return err } - - // This check is needed because d.client is set to a mock in tests - if d.client == nil { - d.client, err = d.getClient(meta) + if d.authProvider == nil { + opts := awsAuth.Options{ + Logger: d.logger, + Properties: metadata.Properties, + Region: meta.Region, + Endpoint: meta.Endpoint, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: meta.SessionToken, + } + cfg := awsAuth.GetConfig(opts) + provider, err := awsAuth.NewProvider(ctx, opts, cfg) if err != nil { return err } + d.authProvider = provider } + d.table = meta.Table d.ttlAttributeName = meta.TTLAttributeName d.partitionKey = meta.PartitionKey @@ -111,8 +123,7 @@ func (d *StateStore) validateTableAccess(ctx context.Context) error { }, }, } - - _, err := d.client.GetItemWithContext(ctx, input) + _, err := d.authProvider.DynamoDB().DynamoDB.GetItemWithContext(ctx, input) return err } @@ -144,8 +155,7 @@ func (d *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.Get }, }, } - - result, err := d.client.GetItemWithContext(ctx, input) + result, err := d.authProvider.DynamoDB().DynamoDB.GetItemWithContext(ctx, input) if err != nil { return nil, err } @@ -217,8 +227,7 @@ func (d *StateStore) Set(ctx context.Context, req *state.SetRequest) error { condExpr := "attribute_not_exists(etag)" input.ConditionExpression = &condExpr } - - _, err = d.client.PutItemWithContext(ctx, input) + _, err = d.authProvider.DynamoDB().DynamoDB.PutItemWithContext(ctx, input) if err != nil && req.HasETag() { switch cErr := err.(type) { case *dynamodb.ConditionalCheckFailedException: @@ -249,8 +258,7 @@ func (d *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error } input.ExpressionAttributeValues = exprAttrValues } - - _, err := d.client.DeleteItemWithContext(ctx, input) + _, err := d.authProvider.DynamoDB().DynamoDB.DeleteItemWithContext(ctx, input) if err != nil { switch cErr := err.(type) { case *dynamodb.ConditionalCheckFailedException: @@ -268,6 +276,9 @@ func (d *StateStore) GetComponentMetadata() (metadataInfo metadata.MetadataMap) } func (d *StateStore) Close() error { + if d.authProvider != nil { + return d.authProvider.Close() + } return nil } @@ -281,16 +292,6 @@ func (d *StateStore) getDynamoDBMetadata(meta state.Metadata) (*dynamoDBMetadata return &m, err } -func (d *StateStore) getClient(metadata *dynamoDBMetadata) (*dynamodb.DynamoDB, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - c := dynamodb.New(sess) - - return c, nil -} - // getItemFromReq converts a dapr state.SetRequest into an dynamodb item func (d *StateStore) getItemFromReq(req *state.SetRequest) (map[string]*dynamodb.AttributeValue, error) { value, err := d.marshalToString(req.Value) @@ -431,8 +432,7 @@ func (d *StateStore) Multi(ctx context.Context, request *state.TransactionalStat } twinput.TransactItems = append(twinput.TransactItems, twi) } - - _, err := d.client.TransactWriteItemsWithContext(ctx, twinput) + _, err := d.authProvider.DynamoDB().DynamoDB.TransactWriteItemsWithContext(ctx, twinput) return err } diff --git a/state/aws/dynamodb/dynamodb_test.go b/state/aws/dynamodb/dynamodb_test.go index d1f98b70ba..7b667b6c78 100644 --- a/state/aws/dynamodb/dynamodb_test.go +++ b/state/aws/dynamodb/dynamodb_test.go @@ -21,26 +21,18 @@ import ( "testing" "time" + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/dapr/components-contrib/state" ) -type mockedDynamoDB struct { - GetItemWithContextFn func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) - PutItemWithContextFn func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error) - DeleteItemWithContextFn func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error) - BatchWriteItemWithContextFn func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error) - TransactWriteItemsWithContextFn func(aws.Context, *dynamodb.TransactWriteItemsInput, ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) - dynamodbiface.DynamoDBAPI -} - type DynamoDBItem struct { Key string `json:"key"` Value string `json:"value"` @@ -52,36 +44,28 @@ const ( pkey = "partitionKey" ) -func (m *mockedDynamoDB) GetItemWithContext(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { - return m.GetItemWithContextFn(ctx, input, op...) -} - -func (m *mockedDynamoDB) PutItemWithContext(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error) { - return m.PutItemWithContextFn(ctx, input, op...) -} - -func (m *mockedDynamoDB) DeleteItemWithContext(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error) { - return m.DeleteItemWithContextFn(ctx, input, op...) -} +func TestInit(t *testing.T) { + m := state.Metadata{} + mockedDB := &awsAuth.MockDynamoDB{ + // We're adding this so we can pass the connection check on Init + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { + return nil, nil + }, + } -func (m *mockedDynamoDB) BatchWriteItemWithContext(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error) { - return m.BatchWriteItemWithContextFn(ctx, input, op...) -} + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } -func (m *mockedDynamoDB) TransactWriteItemsWithContext(ctx context.Context, input *dynamodb.TransactWriteItemsInput, op ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) { - return m.TransactWriteItemsWithContextFn(ctx, input, op...) -} + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } -func TestInit(t *testing.T) { - m := state.Metadata{} - s := &StateStore{ + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, partitionKey: defaultPartitionKeyName, - client: &mockedDynamoDB{ - // We're adding this so we can pass the connection check on Init - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { - return nil, nil - }, - }, } t.Run("NewDynamoDBStateStore Default Partition Key", func(t *testing.T) { @@ -132,16 +116,29 @@ func TestInit(t *testing.T) { }) t.Run("Init with bad table name or permissions", func(t *testing.T) { + table := "does-not-exist" m.Properties = map[string]string{ - "Table": "does-not-exist", - "Region": "eu-west-1", + "Table": table, } - s.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { return nil, errors.New("Requested resource not found") }, } + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + table: table, + } err := s.Init(context.Background(), m) require.Error(t, err) @@ -151,10 +148,7 @@ func TestInit(t *testing.T) { func TestGet(t *testing.T) { t.Run("Successfully retrieve item", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return &dynamodb.GetItemOutput{ Item: map[string]*dynamodb.AttributeValue{ @@ -172,6 +166,20 @@ func TestGet(t *testing.T) { }, } + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } req := &state.GetRequest{ Key: "someKey", Metadata: nil, @@ -179,34 +187,46 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.NoError(t, err) assert.Equal(t, []byte("some value"), out.Data) assert.Equal(t, "1bdead4badc0ffee", *out.ETag) assert.NotContains(t, out.Metadata, "ttlExpireTime") }) t.Run("Successfully retrieve item (with unexpired ttl)", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { - return &dynamodb.GetItemOutput{ - Item: map[string]*dynamodb.AttributeValue{ - "key": { - S: aws.String("someKey"), - }, - "value": { - S: aws.String("some value"), - }, - "testAttributeName": { - N: aws.String("4074862051"), - }, - "etag": { - S: aws.String("1bdead4badc0ffee"), - }, + mockedDB := &awsAuth.MockDynamoDB{ + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { + return &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String("someKey"), }, - }, nil - }, + "value": { + S: aws.String("some value"), + }, + "testAttributeName": { + N: aws.String("4074862051"), + }, + "etag": { + S: aws.String("1bdead4badc0ffee"), + }, + }, + }, nil }, + } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, ttlAttributeName: "testAttributeName", } req := &state.GetRequest{ @@ -216,7 +236,7 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.NoError(t, err) assert.Equal(t, []byte("some value"), out.Data) assert.Equal(t, "1bdead4badc0ffee", *out.ETag) @@ -226,27 +246,39 @@ func TestGet(t *testing.T) { assert.Equal(t, int64(4074862051), expireTime.Unix()) }) t.Run("Successfully retrieve item (with expired ttl)", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { - return &dynamodb.GetItemOutput{ - Item: map[string]*dynamodb.AttributeValue{ - "key": { - S: aws.String("someKey"), - }, - "value": { - S: aws.String("some value"), - }, - "testAttributeName": { - N: aws.String("35489251"), - }, - "etag": { - S: aws.String("1bdead4badc0ffee"), - }, + mockedDB := &awsAuth.MockDynamoDB{ + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { + return &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String("someKey"), + }, + "value": { + S: aws.String("some value"), + }, + "testAttributeName": { + N: aws.String("35489251"), }, - }, nil - }, + "etag": { + S: aws.String("1bdead4badc0ffee"), + }, + }, + }, nil }, + } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, ttlAttributeName: "testAttributeName", } req := &state.GetRequest{ @@ -256,20 +288,33 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.NoError(t, err) assert.Nil(t, out.Data) assert.Nil(t, out.ETag) assert.Nil(t, out.Metadata) }) t.Run("Unsuccessfully get item", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { - return nil, errors.New("failed to retrieve data") - }, + mockedDB := &awsAuth.MockDynamoDB{ + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { + return nil, errors.New("failed to retrieve data") }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + } + req := &state.GetRequest{ Key: "key", Metadata: nil, @@ -277,20 +322,32 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.Error(t, err) assert.Nil(t, out) }) t.Run("Unsuccessfully with empty response", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { - return &dynamodb.GetItemOutput{ - Item: map[string]*dynamodb.AttributeValue{}, - }, nil - }, + mockedDB := &awsAuth.MockDynamoDB{ + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { + return &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{}, + }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + } req := &state.GetRequest{ Key: "key", Metadata: nil, @@ -298,26 +355,38 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.NoError(t, err) assert.Nil(t, out.Data) assert.Nil(t, out.ETag) assert.Nil(t, out.Metadata) }) t.Run("Unsuccessfully with no required key", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { - return &dynamodb.GetItemOutput{ - Item: map[string]*dynamodb.AttributeValue{ - "value2": { - S: aws.String("value"), - }, + mockedDB := &awsAuth.MockDynamoDB{ + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { + return &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{ + "value2": { + S: aws.String("value"), }, - }, nil - }, + }, + }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + } req := &state.GetRequest{ Key: "key", Metadata: nil, @@ -325,7 +394,7 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.NoError(t, err) assert.Empty(t, out.Data) assert.Nil(t, out.ETag) @@ -338,10 +407,7 @@ func TestSet(t *testing.T) { } t.Run("Successfully set item", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -360,21 +426,34 @@ func TestSet(t *testing.T) { }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } + req := &state.SetRequest{ Key: "key", Value: value{ Value: "value", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Successfully set item with matching etag", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -397,6 +476,21 @@ func TestSet(t *testing.T) { }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } etag := "1bdead4badc0ffee" req := &state.SetRequest{ ETag: &etag, @@ -405,15 +499,12 @@ func TestSet(t *testing.T) { Value: "value", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Unsuccessfully set item with mismatched etag", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -431,6 +522,21 @@ func TestSet(t *testing.T) { return nil, &checkErr }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } etag := "bogusetag" req := &state.SetRequest{ ETag: &etag, @@ -440,7 +546,7 @@ func TestSet(t *testing.T) { }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.Error(t, err) switch tagErr := err.(type) { case *state.ETagError: @@ -451,10 +557,7 @@ func TestSet(t *testing.T) { }) t.Run("Successfully set item with first-write-concurrency", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -474,6 +577,21 @@ func TestSet(t *testing.T) { }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } req := &state.SetRequest{ Key: "key", Value: value{ @@ -483,15 +601,12 @@ func TestSet(t *testing.T) { Concurrency: state.FirstWrite, }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Unsuccessfully set item with first-write-concurrency", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -506,6 +621,21 @@ func TestSet(t *testing.T) { return nil, &checkErr }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } req := &state.SetRequest{ Key: "key", Value: value{ @@ -515,7 +645,7 @@ func TestSet(t *testing.T) { Concurrency: state.FirstWrite, }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.Error(t, err) switch err.(type) { case *state.ETagError: @@ -525,10 +655,7 @@ func TestSet(t *testing.T) { }) t.Run("Successfully set item with ttl = -1", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Len(t, input.Item, 4) result := DynamoDBItem{} @@ -547,7 +674,22 @@ func TestSet(t *testing.T) { }, nil }, } - ss.ttlAttributeName = "testAttributeName" + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + ttlAttributeName: "testAttributeName", + partitionKey: defaultPartitionKeyName, + } req := &state.SetRequest{ Key: "someKey", @@ -558,14 +700,11 @@ func TestSet(t *testing.T) { "ttlInSeconds": "-1", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Successfully set item with 'correct' ttl", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Len(t, input.Item, 4) result := DynamoDBItem{} @@ -584,7 +723,22 @@ func TestSet(t *testing.T) { }, nil }, } - ss.ttlAttributeName = "testAttributeName" + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + ttlAttributeName: "testAttributeName", + } req := &state.SetRequest{ Key: "someKey", @@ -595,33 +749,42 @@ func TestSet(t *testing.T) { "ttlInSeconds": "180", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Unsuccessfully set item", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { return nil, errors.New("unable to put item") }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } req := &state.SetRequest{ Key: "key", Value: value{ Value: "value", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.Error(t, err) }) t.Run("Successfully set item with correct ttl but without component metadata", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("someKey"), @@ -640,6 +803,21 @@ func TestSet(t *testing.T) { }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } req := &state.SetRequest{ Key: "someKey", Value: value{ @@ -649,34 +827,46 @@ func TestSet(t *testing.T) { "ttlInSeconds": "180", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Unsuccessfully set item with ttl (invalid value)", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { - assert.Equal(t, map[string]*dynamodb.AttributeValue{ - "key": { - S: aws.String("somekey"), - }, - "value": { - S: aws.String(`{"Value":"somevalue"}`), - }, - "ttlInSeconds": { - N: aws.String("180"), - }, - }, input.Item) + mockedDB := &awsAuth.MockDynamoDB{ + PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { + assert.Equal(t, map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String("somekey"), + }, + "value": { + S: aws.String(`{"Value":"somevalue"}`), + }, + "ttlInSeconds": { + N: aws.String("180"), + }, + }, input.Item) - return &dynamodb.PutItemOutput{ - Attributes: map[string]*dynamodb.AttributeValue{ - "key": { - S: aws.String("value"), - }, + return &dynamodb.PutItemOutput{ + Attributes: map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String("value"), }, - }, nil - }, + }, + }, nil }, + } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, ttlAttributeName: "testAttributeName", } req := &state.SetRequest{ @@ -688,7 +878,7 @@ func TestSet(t *testing.T) { "ttlInSeconds": "invalidvalue", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.Error(t, err) assert.Equal(t, "dynamodb error: failed to parse ttlInSeconds: strconv.ParseInt: parsing \"invalidvalue\": invalid syntax", err.Error()) }) @@ -700,10 +890,7 @@ func TestDelete(t *testing.T) { Key: "key", } - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { @@ -715,7 +902,22 @@ func TestDelete(t *testing.T) { }, } - err := ss.Delete(context.Background(), req) + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } + + err := s.Delete(context.Background(), req) require.NoError(t, err) }) @@ -725,10 +927,8 @@ func TestDelete(t *testing.T) { ETag: &etag, Key: "key", } - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + + mockedDB := &awsAuth.MockDynamoDB{ DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { @@ -744,7 +944,22 @@ func TestDelete(t *testing.T) { }, } - err := ss.Delete(context.Background(), req) + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } + + err := s.Delete(context.Background(), req) require.NoError(t, err) }) @@ -755,10 +970,7 @@ func TestDelete(t *testing.T) { Key: "key", } - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { @@ -775,7 +987,21 @@ func TestDelete(t *testing.T) { }, } - err := ss.Delete(context.Background(), req) + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } + err := s.Delete(context.Background(), req) require.Error(t, err) switch tagErr := err.(type) { case *state.ETagError: @@ -786,26 +1012,36 @@ func TestDelete(t *testing.T) { }) t.Run("Unsuccessfully delete item", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { - return nil, errors.New("unable to delete item") - }, + mockedDB := &awsAuth.MockDynamoDB{ + DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { + return nil, errors.New("unable to delete item") }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + } + req := &state.DeleteRequest{ Key: "key", } - err := ss.Delete(context.Background(), req) + err := s.Delete(context.Background(), req) require.Error(t, err) }) } func TestMultiTx(t *testing.T) { t.Run("Successfully Multiple Transaction Operations", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } firstKey := "key1" secondKey := "key2" secondValue := "value2" @@ -829,7 +1065,7 @@ func TestMultiTx(t *testing.T) { }, } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ TransactWriteItemsWithContextFn: func(ctx context.Context, input *dynamodb.TransactWriteItemsInput, op ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) { // ops - duplicates exOps := len(ops) - 1 @@ -853,13 +1089,28 @@ func TestMultiTx(t *testing.T) { return &dynamodb.TransactWriteItemsOutput{}, nil }, } - ss.table = tableName + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + table: tableName, + partitionKey: defaultPartitionKeyName, + } req := &state.TransactionalStateRequest{ Operations: ops, Metadata: map[string]string{}, } - err := ss.Multi(context.Background(), req) + err := s.Multi(context.Background(), req) require.NoError(t, err) }) } diff --git a/state/mongodb/mongodb.go b/state/mongodb/mongodb.go index 1bd5472ef9..112043d637 100644 --- a/state/mongodb/mongodb.go +++ b/state/mongodb/mongodb.go @@ -25,6 +25,8 @@ import ( "strings" "time" + "github.com/dapr/components-contrib/contenttype" + "github.com/google/uuid" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsonrw" @@ -528,7 +530,18 @@ func (m *MongoDB) doTransaction(sessCtx mongo.SessionContext, operations []state var err error switch req := o.(type) { case state.SetRequest: - err = m.setInternal(sessCtx, &req) + { + isJSON := (len(req.Metadata) > 0 && req.Metadata[metadata.ContentType] == contenttype.JSONContentType) + if isJSON { + if bytes, ok := req.Value.([]byte); ok { + err = json.Unmarshal(bytes, &req.Value) + if err != nil { + break + } + } + } + err = m.setInternal(sessCtx, &req) + } case state.DeleteRequest: err = m.deleteInternal(sessCtx, &req) } diff --git a/state/postgresql/v1/metadata.yaml b/state/postgresql/v1/metadata.yaml index 51803bcbea..03387519ae 100644 --- a/state/postgresql/v1/metadata.yaml +++ b/state/postgresql/v1/metadata.yaml @@ -53,16 +53,12 @@ builtinAuthenticationProfiles: example: | "host=mydb.postgres.database.aws.com user=myapplication port=5432 dbname=dapr_test sslmode=require" type: string - - name: awsRegion - type: string - required: true - description: | - The AWS Region where the AWS Relational Database Service is deployed to. - example: '"us-east-1"' - name: awsAccessKey type: string - required: true + required: false description: | + Deprecated as of Dapr 1.17. Use 'accessKey' instead if using AWS IAM. + If both fields are set, then 'accessKey' value will be used. AWS access key associated with an IAM account. example: '"AKIAIOSFODNN7EXAMPLE"' - name: awsSecretKey @@ -70,6 +66,8 @@ builtinAuthenticationProfiles: required: false sensitive: true description: | + Deprecated as of Dapr 1.17. Use 'secretKey' instead if using AWS IAM. + If both fields are set, then 'secretKey' value will be used. The secret key associated with the access key. example: '"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"' authenticationProfiles: diff --git a/state/postgresql/v2/metadata.go b/state/postgresql/v2/metadata.go index 7903d74a29..8b21fe9e44 100644 --- a/state/postgresql/v2/metadata.go +++ b/state/postgresql/v2/metadata.go @@ -43,7 +43,7 @@ type pgMetadata struct { Timeout time.Duration `mapstructure:"timeout" mapstructurealiases:"timeoutInSeconds"` CleanupInterval *time.Duration `mapstructure:"cleanupInterval" mapstructurealiases:"cleanupIntervalInSeconds"` - aws.AWSIAM `mapstructure:",squash"` + aws.DeprecatedPostgresIAM `mapstructure:",squash"` } func (m *pgMetadata) InitWithMetadata(meta state.Metadata, opts pgauth.InitWithMetadataOpts) error { @@ -60,7 +60,7 @@ func (m *pgMetadata) InitWithMetadata(meta state.Metadata, opts pgauth.InitWithM return err } - // Validate and sanitize input + // Validate and sanitize inputq err = m.PostgresAuthMetadata.InitWithMetadata(meta.Properties, opts) if err != nil { return err diff --git a/state/postgresql/v2/metadata.yaml b/state/postgresql/v2/metadata.yaml index ada800e4d4..de6103918d 100644 --- a/state/postgresql/v2/metadata.yaml +++ b/state/postgresql/v2/metadata.yaml @@ -52,16 +52,12 @@ builtinAuthenticationProfiles: example: | "host=mydb.postgres.database.aws.com user=myapplication port=5432 dbname=dapr_test sslmode=require" type: string - - name: awsRegion - type: string - required: true - description: | - The AWS Region where the AWS Relational Database Service is deployed to. - example: '"us-east-1"' - name: awsAccessKey type: string required: false description: | + Deprecated as of Dapr 1.17. Use 'accessKey' instead if using AWS IAM. + If both fields are set, then 'accessKey' value will be used. AWS access key associated with an IAM account. example: '"AKIAIOSFODNN7EXAMPLE"' - name: awsSecretKey @@ -69,6 +65,8 @@ builtinAuthenticationProfiles: required: false sensitive: true description: | + Deprecated as of Dapr 1.17. Use 'secretKey' instead if using AWS IAM. + If both fields are set, then 'secretKey' value will be used. The secret key associated with the access key. example: '"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"' authenticationProfiles: diff --git a/state/postgresql/v2/postgresql.go b/state/postgresql/v2/postgresql.go index d323ca5c90..2b80567834 100644 --- a/state/postgresql/v2/postgresql.go +++ b/state/postgresql/v2/postgresql.go @@ -28,6 +28,7 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" pgauth "github.com/dapr/components-contrib/common/authentication/postgresql" pginterfaces "github.com/dapr/components-contrib/common/component/postgresql/interfaces" pgtransactions "github.com/dapr/components-contrib/common/component/postgresql/transactions" @@ -51,6 +52,8 @@ type PostgreSQL struct { enableAzureAD bool enableAWSIAM bool + + awsAuthProvider awsAuth.Provider } type Options struct { @@ -98,17 +101,32 @@ func (p *PostgreSQL) Init(ctx context.Context, meta state.Metadata) (err error) return err } + if opts.AWSIAMEnabled && p.metadata.UseAWSIAM { + opts, validateErr := p.metadata.BuildAwsIamOptions(p.logger, meta.Properties) + if validateErr != nil { + return fmt.Errorf("failed to validate AWS IAM authentication fields: %w", validateErr) + } + + var provider awsAuth.Provider + provider, err = awsAuth.NewProvider(ctx, *opts, awsAuth.GetConfig(*opts)) + if err != nil { + return err + } + p.awsAuthProvider = provider + p.awsAuthProvider.UpdatePostgres(ctx, config) + } + connCtx, connCancel := context.WithTimeout(ctx, p.metadata.Timeout) + defer connCancel() p.db, err = pgxpool.NewWithConfig(connCtx, config) - connCancel() if err != nil { err = fmt.Errorf("failed to connect to the database: %w", err) return err } pingCtx, pingCancel := context.WithTimeout(ctx, p.metadata.Timeout) + defer pingCancel() err = p.db.Ping(pingCtx) - pingCancel() if err != nil { err = fmt.Errorf("failed to ping the database: %w", err) return err @@ -534,11 +552,15 @@ func (p *PostgreSQL) Close() error { p.db = nil } + errs := make([]error, 2) if p.gc != nil { - return p.gc.Close() + errs[0] = p.gc.Close() } - return nil + if p.awsAuthProvider != nil { + errs[1] = p.awsAuthProvider.Close() + } + return errors.Join(errs...) } // GetCleanupInterval returns the cleanupInterval property. diff --git a/tests/certification/bindings/aws/s3/s3_test.go b/tests/certification/bindings/aws/s3/s3_test.go index 16e41abc49..c815901f8c 100644 --- a/tests/certification/bindings/aws/s3/s3_test.go +++ b/tests/certification/bindings/aws/s3/s3_test.go @@ -279,7 +279,7 @@ func S3SForcePathStyle(t *testing.T) { Step(sidecar.Run(sidecarName, append(componentRuntimeOptions(), embedded.WithoutApp(), - embedded.WithComponentsPath("./components/forcePathStyleTrue"), + embedded.WithResourcesPath("./components/forcePathStyleTrue"), embedded.WithDaprGRPCPort(strconv.Itoa(currentGRPCPort)), embedded.WithDaprHTTPPort(strconv.Itoa(currentHTTPPort)), )..., diff --git a/tests/certification/go.mod b/tests/certification/go.mod index 3c5698de6b..b7540d32cc 100644 --- a/tests/certification/go.mod +++ b/tests/certification/go.mod @@ -84,21 +84,22 @@ require ( github.com/ardielle/ardielle-go v1.5.2 // indirect github.com/armon/go-metrics v0.4.1 // indirect github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect - github.com/aws/aws-msk-iam-sasl-signer-go v1.0.0 // indirect - github.com/aws/aws-sdk-go-v2 v1.31.0 // indirect - github.com/aws/aws-sdk-go-v2/config v1.27.39 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.17.37 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.14 // indirect + github.com/aws/aws-msk-iam-sasl-signer-go v1.0.1-0.20241125194140-078c08b8574a // indirect + github.com/aws/aws-sdk-go-v2 v1.32.4 // indirect + github.com/aws/aws-sdk-go-v2/config v1.28.2 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.43 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.19 // indirect github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.3.10 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.18 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.18 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.23 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.23 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.5 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.20 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.23.3 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.27.3 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.31.3 // indirect - github.com/aws/smithy-go v1.21.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.4 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.24.4 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.4 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.32.4 // indirect + github.com/aws/rolesanywhere-credential-helper v1.0.4 // indirect + github.com/aws/smithy-go v1.22.0 // indirect github.com/benbjohnson/clock v1.3.5 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bits-and-blooms/bitset v1.4.0 // indirect @@ -291,6 +292,7 @@ require ( github.com/tklauser/numcpus v0.6.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.51.0 // indirect + github.com/vmware/vmware-go-kcl v1.5.1 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.1.2 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect @@ -334,6 +336,7 @@ require ( google.golang.org/grpc v1.64.1 // indirect google.golang.org/protobuf v1.34.2 // indirect gopkg.in/inf.v0 v0.9.1 // indirect + gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect k8s.io/api v0.30.2 // indirect diff --git a/tests/certification/go.sum b/tests/certification/go.sum index 145f05a305..1055475a16 100644 --- a/tests/certification/go.sum +++ b/tests/certification/go.sum @@ -192,51 +192,53 @@ github.com/asaskevich/govalidator v0.0.0-20200108200545-475eaeb16496/go.mod h1:o github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/aws/aws-lambda-go v1.13.3/go.mod h1:4UKl9IzQMoD+QF79YdCuzCwp8VbmG4VAQwij/eHl5CU= -github.com/aws/aws-msk-iam-sasl-signer-go v1.0.0 h1:UyjtGmO0Uwl/K+zpzPwLoXzMhcN9xmnR2nrqJoBrg3c= -github.com/aws/aws-msk-iam-sasl-signer-go v1.0.0/go.mod h1:TJAXuFs2HcMib3sN5L0gUC+Q01Qvy3DemvA55WuC+iA= +github.com/aws/aws-msk-iam-sasl-signer-go v1.0.1-0.20241125194140-078c08b8574a h1:QFemvMGPnajaeRBkFc1HoEA7qzVjUv+rkYb1/ps1/UE= +github.com/aws/aws-msk-iam-sasl-signer-go v1.0.1-0.20241125194140-078c08b8574a/go.mod h1:MVYeeOhILFFemC/XlYTClvBjYZrg/EPd3ts885KrNTI= github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/aws/aws-sdk-go v1.32.6/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= github.com/aws/aws-sdk-go v1.50.19 h1:YSIDKRSkh/TW0RPWoocdLqtC/T5W6IGBVhFs6P7Qcac= github.com/aws/aws-sdk-go v1.50.19/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= github.com/aws/aws-sdk-go-v2 v1.9.2/go.mod h1:cK/D0BBs0b/oWPIcX/Z/obahJK1TT7IPVjy53i/mX/4= -github.com/aws/aws-sdk-go-v2 v1.31.0 h1:3V05LbxTSItI5kUqNwhJrrrY1BAXxXt0sN0l72QmG5U= -github.com/aws/aws-sdk-go-v2 v1.31.0/go.mod h1:ztolYtaEUtdpf9Wftr31CJfLVjOnD/CVRkKOOYgF8hA= +github.com/aws/aws-sdk-go-v2 v1.32.4 h1:S13INUiTxgrPueTmrm5DZ+MiAo99zYzHEFh1UNkOxNE= +github.com/aws/aws-sdk-go-v2 v1.32.4/go.mod h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v/HpXKM1ZUo= github.com/aws/aws-sdk-go-v2/config v1.8.3/go.mod h1:4AEiLtAb8kLs7vgw2ZV3p2VZ1+hBavOc84hqxVNpCyw= -github.com/aws/aws-sdk-go-v2/config v1.27.39 h1:FCylu78eTGzW1ynHcongXK9YHtoXD5AiiUqq3YfJYjU= -github.com/aws/aws-sdk-go-v2/config v1.27.39/go.mod h1:wczj2hbyskP4LjMKBEZwPRO1shXY+GsQleab+ZXT2ik= +github.com/aws/aws-sdk-go-v2/config v1.28.2 h1:FLvWA97elBiSPdIol4CXfIAY1wlq3KzoSgkMuZSuSe8= +github.com/aws/aws-sdk-go-v2/config v1.28.2/go.mod h1:hNmQsKfUqpKz2yfnZUB60GCemPmeqAalVTui0gOxjAE= github.com/aws/aws-sdk-go-v2/credentials v1.4.3/go.mod h1:FNNC6nQZQUuyhq5aE5c7ata8o9e4ECGmS4lAXC7o1mQ= -github.com/aws/aws-sdk-go-v2/credentials v1.17.37 h1:G2aOH01yW8X373JK419THj5QVqu9vKEwxSEsGxihoW0= -github.com/aws/aws-sdk-go-v2/credentials v1.17.37/go.mod h1:0ecCjlb7htYCptRD45lXJ6aJDQac6D2NlKGpZqyTG6A= +github.com/aws/aws-sdk-go-v2/credentials v1.17.43 h1:SEGdVOOE1Wyr2XFKQopQ5GYjym3nYHcphesdt78rNkY= +github.com/aws/aws-sdk-go-v2/credentials v1.17.43/go.mod h1:3aiza5kSyAE4eujSanOkSkAmX/RnVqslM+GRQ/Xvv4c= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.6.0/go.mod h1:gqlclDEZp4aqJOancXK6TN24aKhT0W0Ae9MHk3wzTMM= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.14 h1:C/d03NAmh8C4BZXhuRNboF/DqhBkBCeDiJDcaqIT5pA= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.14/go.mod h1:7I0Ju7p9mCIdlrfS+JCgqcYD0VXz/N4yozsox+0o078= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.19 h1:woXadbf0c7enQ2UGCi8gW/WuKmE0xIzxBF/eD94jMKQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.19/go.mod h1:zminj5ucw7w0r65bP6nhyOd3xL6veAUMc3ElGMoLVb4= github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.3.10 h1:z6fAXB4HSuYjrE/P8RU3NdCaN+EPaeq/+80aisCjuF8= github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.3.10/go.mod h1:PoPjOi7j+/DtKIGC58HRfcdWKBPYYXwdKnRG+po+hzo= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.18 h1:kYQ3H1u0ANr9KEKlGs/jTLrBFPo8P8NaH/w7A01NeeM= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.18/go.mod h1:r506HmK5JDUh9+Mw4CfGJGSSoqIiLCndAuqXuhbv67Y= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.18 h1:Z7IdFUONvTcvS7YuhtVxN99v2cCoHRXOS4mTr0B/pUc= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.18/go.mod h1:DkKMmksZVVyat+Y+r1dEOgJEfUeA7UngIHWeKsi0yNc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.23 h1:A2w6m6Tmr+BNXjDsr7M90zkWjsu4JXHwrzPg235STs4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.23/go.mod h1:35EVp9wyeANdujZruvHiQUAo9E3vbhnIO1mTCAxMlY0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.23 h1:pgYW9FCabt2M25MoHYCfMrVY2ghiiBKYWUVXfwZs+sU= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.23/go.mod h1:c48kLgzO19wAu3CPkDWC28JbaJ+hfQlsdl7I2+oqIbk= github.com/aws/aws-sdk-go-v2/internal/ini v1.2.4/go.mod h1:ZcBrrI3zBKlhGFNYWvju0I3TR93I7YIgAfy82Fh4lcQ= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= github.com/aws/aws-sdk-go-v2/service/appconfig v1.4.2/go.mod h1:FZ3HkCe+b10uFZZkFdvf98LHW21k49W8o8J366lqVKY= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.5 h1:QFASJGfT8wMXtuP3D5CRmMjARHv9ZmzFUMJznHDOY3w= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.5/go.mod h1:QdZ3OmoIjSX+8D1OPAzPxDfjXASbBMDsz9qvtyIhtik= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0 h1:TToQNkvGguu209puTojY/ozlqy2d/SFNcoLIqTFi42g= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0/go.mod h1:0jp+ltwkf+SwG2fm/PKo8t4y8pJSgOCO4D8Lz3k0aHQ= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.3.2/go.mod h1:72HRZDLMtmVQiLG2tLfQcaWLCssELvGl+Zf2WVxMmR8= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.20 h1:Xbwbmk44URTiHNx6PNo0ujDE6ERlsCKJD3u1zfnzAPg= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.20/go.mod h1:oAfOFzUB14ltPZj1rWwRc3d/6OgD76R8KlvU3EqM9Fg= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.4 h1:tHxQi/XHPK0ctd/wdOw0t7Xrc2OxcRCnVzv8lwWPu0c= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.4/go.mod h1:4GQbF1vJzG60poZqWatZlhP31y8PGCCVTvIGPdaaYJ0= github.com/aws/aws-sdk-go-v2/service/sso v1.4.2/go.mod h1:NBvT9R1MEF+Ud6ApJKM0G+IkPchKS7p7c2YPKwHmBOk= -github.com/aws/aws-sdk-go-v2/service/sso v1.23.3 h1:rs4JCczF805+FDv2tRhZ1NU0RB2H6ryAvsWPanAr72Y= -github.com/aws/aws-sdk-go-v2/service/sso v1.23.3/go.mod h1:XRlMvmad0ZNL+75C5FYdMvbbLkd6qiqz6foR1nA1PXY= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.27.3 h1:S7EPdMVZod8BGKQQPTBK+FcX9g7bKR7c4+HxWqHP7Vg= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.27.3/go.mod h1:FnvDM4sfa+isJ3kDXIzAB9GAwVSzFzSy97uZ3IsHo4E= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.4 h1:BqE3NRG6bsODh++VMKMsDmFuJTHrdD4rJZqHjDeF6XI= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.4/go.mod h1:wrMCEwjFPms+V86TCQQeOxQF/If4vT44FGIOFiMC2ck= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.4 h1:zcx9LiGWZ6i6pjdcoE9oXAB6mUdeyC36Ia/QEiIvYdg= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.4/go.mod h1:Tp/ly1cTjRLGBBmNccFumbZ8oqpZlpdhFf80SrRh4is= github.com/aws/aws-sdk-go-v2/service/sts v1.7.2/go.mod h1:8EzeIqfWt2wWT4rJVu3f21TfrhJ8AEMzVybRNSb/b4g= -github.com/aws/aws-sdk-go-v2/service/sts v1.31.3 h1:VzudTFrDCIDakXtemR7l6Qzt2+JYsVqo2MxBPt5k8T8= -github.com/aws/aws-sdk-go-v2/service/sts v1.31.3/go.mod h1:yMWe0F+XG0DkRZK5ODZhG7BEFYhLXi2dqGsv6tX0cgI= +github.com/aws/aws-sdk-go-v2/service/sts v1.32.4 h1:yDxvkz3/uOKfxnv8YhzOi9m+2OGIxF+on3KOISbK5IU= +github.com/aws/aws-sdk-go-v2/service/sts v1.32.4/go.mod h1:9XEUty5v5UAsMiFOBJrNibZgwCeOma73jgGwwhgffa8= +github.com/aws/rolesanywhere-credential-helper v1.0.4 h1:kHIVVdyQQiFZoKBP+zywBdFilGCS8It+UvW5LolKbW8= +github.com/aws/rolesanywhere-credential-helper v1.0.4/go.mod h1:QVGNxlDlYhjR0/ZUee7uGl0hNChWidNpe2+GD87Buqk= github.com/aws/smithy-go v1.8.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E= -github.com/aws/smithy-go v1.21.0 h1:H7L8dtDRk0P1Qm6y0ji7MCYMQObJ5R9CRpyPhRUkLYA= -github.com/aws/smithy-go v1.21.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/aws/smithy-go v1.22.0 h1:uunKnWlcoL3zO7q+gG2Pk53joueEOsnNB28QdMsmiMM= +github.com/aws/smithy-go v1.22.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o= @@ -1381,6 +1383,8 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g= +github.com/vmware/vmware-go-kcl v1.5.1 h1:1rJLfAX4sDnCyatNoD/WJzVafkwST6u/cgY/Uf2VgHk= +github.com/vmware/vmware-go-kcl v1.5.1/go.mod h1:kXJmQ6h0dRMRrp1uWU9XbIXvwelDpTxSPquvQUBdpbo= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= diff --git a/tests/conformance/state/state.go b/tests/conformance/state/state.go index 25de650cf9..cfb94dbfbb 100644 --- a/tests/conformance/state/state.go +++ b/tests/conformance/state/state.go @@ -15,6 +15,7 @@ package state import ( "context" + "encoding/base64" "encoding/json" "fmt" "slices" @@ -784,6 +785,70 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St assert.Equal(t, v, res.Data) } }) + + t.Run("transaction-serialization-grpc-json", func(t *testing.T) { + features := statestore.Features() + // this check for exclude redis 7 + if state.FeatureQueryAPI.IsPresent(features) { + json := "{\"id\":1223,\"name\":\"test\"}" + keyTest1 := key + "-key-grpc" + valueTest := []byte(json) + keyTest2 := key + "-key-grpc-no-json" + + metadataTest1 := map[string]string{ + "contentType": "application/json", + } + + operations := []state.TransactionalStateOperation{ + state.SetRequest{ + Key: keyTest1, + Value: valueTest, + Metadata: metadataTest1, + }, + state.SetRequest{ + Key: keyTest2, + Value: valueTest, + }, + } + + expected := map[string][]byte{ + keyTest1: []byte(json), + keyTest2: []byte(json), + } + + expectedMetadata := map[string]map[string]string{ + keyTest1: metadataTest1, + } + + // Act + transactionStore, ok := statestore.(state.TransactionalStore) + assert.True(t, ok) + err := transactionStore.Multi(context.Background(), &state.TransactionalStateRequest{ + Operations: operations, + }) + require.NoError(t, err) + + // Assert + for k, v := range expected { + res, err := statestore.Get(context.Background(), &state.GetRequest{ + Key: k, + Metadata: expectedMetadata[k], + }) + expectedValue := res.Data + + // In redisjson when set the value with contentType = application/Json store the value in base64 + if strings.HasPrefix(string(expectedValue), "\"ey") { + valueBase64 := strings.Trim(string(expectedValue), "\"") + expectedValueDecoded, _ := base64.StdEncoding.DecodeString(valueBase64) + require.NoError(t, err) + assert.Equal(t, expectedValueDecoded, v) + } else { + require.NoError(t, err) + assert.Equal(t, expectedValue, v) + } + } + } + }) } else { t.Run("component does not implement TransactionalStore interface", func(t *testing.T) { _, ok := statestore.(state.TransactionalStore) diff --git a/tests/conformance/workflows/workflows.go b/tests/conformance/workflows/workflows.go index 7e8e3310af..180ef1cf77 100644 --- a/tests/conformance/workflows/workflows.go +++ b/tests/conformance/workflows/workflows.go @@ -15,12 +15,12 @@ package workflows import ( "context" - "encoding/json" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/wrapperspb" "github.com/dapr/kit/logger" @@ -61,14 +61,12 @@ func ConformanceTests(t *testing.T, props map[string]string, workflowItem workfl t.Run("start", func(t *testing.T) { testLogger.Info("Start test running...") - inputBytes, _ := json.Marshal(10) // Time that the activity within the workflow runs for - testInstanceID := "TestID" t.Run("start", func(t *testing.T) { req := &workflows.StartRequest{ - InstanceID: testInstanceID, + InstanceID: &testInstanceID, WorkflowName: "TestWorkflow", - WorkflowInput: inputBytes, + WorkflowInput: wrapperspb.String("10"), Options: map[string]string{ "task_queue": "TestTaskQueue", }, diff --git a/workflows/requests.go b/workflows/requests.go index 7f4019c729..fcc7a9e6c1 100644 --- a/workflows/requests.go +++ b/workflows/requests.go @@ -1,11 +1,13 @@ package workflows +import "google.golang.org/protobuf/types/known/wrapperspb" + // StartRequest is the struct describing a start workflow request. type StartRequest struct { - InstanceID string `json:"instanceID"` - Options map[string]string `json:"options"` - WorkflowName string `json:"workflowName"` - WorkflowInput []byte `json:"workflowInput"` + InstanceID *string `json:"instanceID"` + Options map[string]string `json:"options"` + WorkflowName string `json:"workflowName"` + WorkflowInput *wrapperspb.StringValue `json:"workflowInput"` } // GetRequest is the struct describing a get workflow state request. @@ -16,14 +18,14 @@ type GetRequest struct { // TerminateRequest is the struct describing a terminate workflow request. type TerminateRequest struct { InstanceID string `json:"instanceID"` - Recursive bool `json:"recursive"` + Recursive *bool `json:"recursive"` } // RaiseEventRequest is the struct describing a raise workflow event request. type RaiseEventRequest struct { - InstanceID string `json:"instanceID"` - EventName string `json:"name"` - EventData []byte `json:"data"` + InstanceID string `json:"instanceID"` + EventName string `json:"name"` + EventData *wrapperspb.StringValue `json:"data"` } // PauseRequest is the struct describing a pause workflow request. @@ -39,5 +41,5 @@ type ResumeRequest struct { // PurgeRequest is the object describing a Purge request. type PurgeRequest struct { InstanceID string `json:"instanceID"` - Recursive bool `json:"recursive"` + Recursive *bool `json:"recursive"` }