diff --git a/common/authentication/aws/client.go b/common/authentication/aws/client.go index 11b26e4988..c7829d7966 100644 --- a/common/authentication/aws/client.go +++ b/common/authentication/aws/client.go @@ -51,7 +51,7 @@ type Clients struct { Dynamo *DynamoDBClients sns *SnsClients sqs *SqsClients - snssqs *SnsSqsClients + Snssqs *SnsSqsClients Secret *SecretManagerClients ParameterStore *ParameterStoreClients kinesis *KinesisClients @@ -75,8 +75,8 @@ func (c *Clients) refresh(session *session.Session) error { c.sns.New(session) case c.sqs != nil: c.sqs.New(session) - case c.snssqs != nil: - c.snssqs.New(session) + case c.Snssqs != nil: + c.Snssqs.New(session) case c.Secret != nil: c.Secret.New(session) case c.ParameterStore != nil: @@ -111,6 +111,12 @@ type SnsSqsClients struct { Sns *sns.SNS Sqs *sqs.SQS Sts *sts.STS + + region string +} + +func (c *SnsSqsClients) Region() string { + return c.region } type SnsClients struct { @@ -164,10 +170,18 @@ func (c *SnsClients) New(session *session.Session) { c.Sns = sns.New(session, session.Config) } +func (c *SnsSqsClients) SetRegion(region string) { + c.region = region +} + func (c *SnsSqsClients) New(session *session.Session) { c.Sns = sns.New(session, session.Config) c.Sqs = sqs.New(session, session.Config) c.Sts = sts.New(session, session.Config) + // the empty string check is added to allow for this to be injected by tests + if c.region != "" { + c.SetRegion(*session.Config.Region) + } } func (c *SqsClients) New(session *session.Session) { diff --git a/common/authentication/aws/static.go b/common/authentication/aws/static.go index e79dee1841..61416e7e5d 100644 --- a/common/authentication/aws/static.go +++ b/common/authentication/aws/static.go @@ -168,14 +168,14 @@ func (a *StaticAuth) SnsSqs() *SnsSqsClients { a.mu.Lock() defer a.mu.Unlock() - if a.clients.snssqs != nil { - return a.clients.snssqs + if a.clients.Snssqs != nil { + return a.clients.Snssqs } clients := SnsSqsClients{} - a.clients.snssqs = &clients - a.clients.snssqs.New(a.session) - return a.clients.snssqs + a.clients.Snssqs = &clients + a.clients.Snssqs.New(a.session) + return a.clients.Snssqs } func (a *StaticAuth) SecretManager() *SecretManagerClients { @@ -250,6 +250,9 @@ func (a *StaticAuth) UpdatePostgres(ctx context.Context, poolConfig *pgxpool.Con if err != nil { return fmt.Errorf("failed to get database token: %w", err) } + if pwd == "" { + return errors.New("failed to get a valid password for the database token for AWS IAM authentication") + } pgConfig.Password = pwd poolConfig.ConnConfig.Password = pwd diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go index 6556ece74c..299605b24c 100644 --- a/common/authentication/aws/x509.go +++ b/common/authentication/aws/x509.go @@ -242,14 +242,14 @@ func (a *x509) SnsSqs() *SnsSqsClients { a.mu.Lock() defer a.mu.Unlock() - if a.clients.snssqs != nil { - return a.clients.snssqs + if a.clients.Snssqs != nil { + return a.clients.Snssqs } clients := SnsSqsClients{} - a.clients.snssqs = &clients - a.clients.snssqs.New(a.session) - return a.clients.snssqs + a.clients.Snssqs = &clients + a.clients.Snssqs.New(a.session) + return a.clients.Snssqs } func (a *x509) SecretManager() *SecretManagerClients { diff --git a/pubsub/aws/snssqs/metadata.go b/pubsub/aws/snssqs/metadata.go index f79ed207af..ff1d0402ad 100644 --- a/pubsub/aws/snssqs/metadata.go +++ b/pubsub/aws/snssqs/metadata.go @@ -6,8 +6,6 @@ import ( "github.com/dapr/components-contrib/pubsub" "github.com/dapr/kit/metadata" - - "github.com/aws/aws-sdk-go/aws/endpoints" ) type snsSqsMetadata struct { @@ -84,14 +82,6 @@ func (s *snsSqs) getSnsSqsMetadata(meta pubsub.Metadata) (*snsSqsMetadata, error return nil, err } - if md.Region != "" { - if partition, ok := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), md.Region); ok { - md.internalPartition = partition.ID() - } else { - md.internalPartition = "aws" - } - } - if md.SqsQueueName == "" { return nil, errors.New("consumerID must be set") } diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index e84dbeb80a..7578f46160 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -33,6 +33,7 @@ import ( "github.com/dapr/kit/retry" + "github.com/aws/aws-sdk-go/aws/endpoints" gonanoid "github.com/matoous/go-nanoid/v2" awsAuth "github.com/dapr/components-contrib/common/authentication/aws" @@ -140,7 +141,6 @@ func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error { if err != nil { return err } - s.metadata = m if s.authProvider == nil { @@ -162,6 +162,13 @@ func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error { s.authProvider = provider } + if s.authProvider.SnsSqs().Region() != "" { + if partition, ok := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), s.authProvider.SnsSqs().Region()); ok { + s.metadata.internalPartition = partition.ID() + } else { + s.metadata.internalPartition = "aws" + } + } s.opsTimeout = time.Duration(m.AssetsManagementTimeoutSeconds * float64(time.Second)) err = s.setAwsAccountIDIfNotProvided(ctx) @@ -201,7 +208,7 @@ func (s *snsSqs) setAwsAccountIDIfNotProvided(parentCtx context.Context) error { } func (s *snsSqs) buildARN(serviceName, entityName string) string { - return fmt.Sprintf("arn:%s:%s:%s:%s:%s", s.metadata.internalPartition, serviceName, s.metadata.Region, s.metadata.AccountID, entityName) + return fmt.Sprintf("arn:%s:%s:%s:%s:%s", s.metadata.internalPartition, serviceName, s.authProvider.SnsSqs().Region(), s.metadata.AccountID, entityName) } func (s *snsSqs) createTopic(parentCtx context.Context, topic string) (string, error) { @@ -257,9 +264,8 @@ func (s *snsSqs) getOrCreateTopic(ctx context.Context, topic string) (topicArn s } // creating queues is idempotent, the names serve as unique keys among a given region. - s.logger.Debugf("No SNS topic ARN found for topic: %s. creating SNS with (sanitized) topic: %s", topic, sanitizedTopic) - if !s.metadata.DisableEntityManagement { + s.logger.Debugf("No SNS topic ARN found for topic: %s. creating SNS with (sanitized) topic: %s", topic, sanitizedTopic) topicArn, err = s.createTopic(ctx, sanitizedTopic) if err != nil { err = fmt.Errorf("error creating new (sanitized) topic '%s': %w", topic, err) @@ -267,10 +273,14 @@ func (s *snsSqs) getOrCreateTopic(ctx context.Context, topic string) (topicArn s return topicArn, sanitizedTopic, err } } else { + s.logger.Debugf("No SNS topic ARN found for topic: %s. Checking AWS SNS for if (sanitized) topic exists: %s", topic, sanitizedTopic) topicArn, err = s.getTopicArn(ctx, sanitizedTopic) if err != nil { + var awsErr awserr.Error + if errors.As(err, &awsErr) && awsErr.Code() == sns.ErrCodeNotFoundException { + return topicArn, sanitizedTopic, errors.New("topic not found") + } err = fmt.Errorf("error fetching info for (sanitized) topic: %s. wrapped error is: %w", topic, err) - return topicArn, sanitizedTopic, err } } @@ -855,7 +865,9 @@ func (s *snsSqs) Publish(ctx context.Context, req *pubsub.PublishRequest) error topicArn, _, err := s.getOrCreateTopic(ctx, req.Topic) if err != nil { - s.logger.Errorf("error getting topic ARN for %s: %v", req.Topic, err) + wrappedErr := fmt.Errorf("error getting topic ARN for %s: %v", req.Topic, err) + s.logger.Error(wrappedErr) + return wrappedErr } message := string(req.Data) diff --git a/pubsub/aws/snssqs/snssqs_test.go b/pubsub/aws/snssqs/snssqs_test.go index 82f58bcbff..2f97320cb4 100644 --- a/pubsub/aws/snssqs/snssqs_test.go +++ b/pubsub/aws/snssqs/snssqs_test.go @@ -19,6 +19,8 @@ import ( "github.com/stretchr/testify/require" + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" + "github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/pubsub" "github.com/dapr/kit/logger" @@ -445,8 +447,17 @@ func Test_buildARN_DefaultPartition(t *testing.T) { r := require.New(t) l := logger.NewLogger("SnsSqs unit test") l.SetOutputLevel(logger.DebugLevel) + mockAuthProvider := &awsAuth.StaticAuth{} + mockedSnssqs := &awsAuth.SnsSqsClients{} + mockedSnssqs.SetRegion("r") + mockedClients := awsAuth.Clients{ + Snssqs: mockedSnssqs, + } + mockAuthProvider.WithMockClients(&mockedClients) + ps := snsSqs{ - logger: l, + logger: l, + authProvider: mockAuthProvider, } md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ @@ -459,6 +470,9 @@ func Test_buildARN_DefaultPartition(t *testing.T) { md.AccountID = "123456789012" ps.metadata = md + // This is now set in the Init func, so we must set it in test to have the value to build the arn with + ps.metadata.internalPartition = "aws" + arn := ps.buildARN("sns", "myTopic") r.Equal("arn:aws:sns:r:123456789012:myTopic", arn) } @@ -468,8 +482,17 @@ func Test_buildARN_StandardPartition(t *testing.T) { r := require.New(t) l := logger.NewLogger("SnsSqs unit test") l.SetOutputLevel(logger.DebugLevel) + mockAuthProvider := &awsAuth.StaticAuth{} + mockedSnssqs := &awsAuth.SnsSqsClients{} + mockedSnssqs.SetRegion("us-west-2") + mockedClients := awsAuth.Clients{ + Snssqs: mockedSnssqs, + } + mockAuthProvider.WithMockClients(&mockedClients) + ps := snsSqs{ - logger: l, + logger: l, + authProvider: mockAuthProvider, } md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ @@ -482,6 +505,9 @@ func Test_buildARN_StandardPartition(t *testing.T) { md.AccountID = "123456789012" ps.metadata = md + // This is now set in the Init func, so we must set it in test to have the value to build the arn with + ps.metadata.internalPartition = "aws" + arn := ps.buildARN("sns", "myTopic") r.Equal("arn:aws:sns:us-west-2:123456789012:myTopic", arn) } @@ -491,8 +517,17 @@ func Test_buildARN_NonStandardPartition(t *testing.T) { r := require.New(t) l := logger.NewLogger("SnsSqs unit test") l.SetOutputLevel(logger.DebugLevel) + mockAuthProvider := &awsAuth.StaticAuth{} + mockedSnssqs := &awsAuth.SnsSqsClients{} + mockedSnssqs.SetRegion("cn-northwest-1") + mockedClients := awsAuth.Clients{ + Snssqs: mockedSnssqs, + } + mockAuthProvider.WithMockClients(&mockedClients) + ps := snsSqs{ - logger: l, + logger: l, + authProvider: mockAuthProvider, } md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ @@ -505,6 +540,9 @@ func Test_buildARN_NonStandardPartition(t *testing.T) { md.AccountID = "123456789012" ps.metadata = md + // This is now set in the Init func, so we must set it in test to have the value to build the arn with + ps.metadata.internalPartition = "aws-cn" + arn := ps.buildARN("sns", "myTopic") r.Equal("arn:aws-cn:sns:cn-northwest-1:123456789012:myTopic", arn) }