From 901420e34d4b865ac2a0c8257fe7fb2523282b33 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Tue, 10 Dec 2024 09:52:00 -0600 Subject: [PATCH 1/4] fix(snssqs): if topic dne and we do not create, then we skip the actual err Signed-off-by: Samantha Coyle --- common/authentication/aws/client.go | 7 +++++++ common/authentication/aws/static.go | 3 +++ pubsub/aws/snssqs/metadata.go | 10 ---------- pubsub/aws/snssqs/snssqs.go | 27 ++++++++++++++++++++------- 4 files changed, 30 insertions(+), 17 deletions(-) diff --git a/common/authentication/aws/client.go b/common/authentication/aws/client.go index 11b26e4988..5cb4bf1fd8 100644 --- a/common/authentication/aws/client.go +++ b/common/authentication/aws/client.go @@ -111,6 +111,12 @@ type SnsSqsClients struct { Sns *sns.SNS Sqs *sqs.SQS Sts *sts.STS + + region string +} + +func (c *SnsSqsClients) Region() string { + return c.region } type SnsClients struct { @@ -168,6 +174,7 @@ func (c *SnsSqsClients) New(session *session.Session) { c.Sns = sns.New(session, session.Config) c.Sqs = sqs.New(session, session.Config) c.Sts = sts.New(session, session.Config) + c.region = *session.Config.Region } func (c *SqsClients) New(session *session.Session) { diff --git a/common/authentication/aws/static.go b/common/authentication/aws/static.go index e79dee1841..5fa887c474 100644 --- a/common/authentication/aws/static.go +++ b/common/authentication/aws/static.go @@ -250,6 +250,9 @@ func (a *StaticAuth) UpdatePostgres(ctx context.Context, poolConfig *pgxpool.Con if err != nil { return fmt.Errorf("failed to get database token: %w", err) } + if pwd == "" { + return errors.New("failed to get a valid password for the database token for AWS IAM authentication") + } pgConfig.Password = pwd poolConfig.ConnConfig.Password = pwd diff --git a/pubsub/aws/snssqs/metadata.go b/pubsub/aws/snssqs/metadata.go index f79ed207af..ff1d0402ad 100644 --- a/pubsub/aws/snssqs/metadata.go +++ b/pubsub/aws/snssqs/metadata.go @@ -6,8 +6,6 @@ import ( "github.com/dapr/components-contrib/pubsub" "github.com/dapr/kit/metadata" - - "github.com/aws/aws-sdk-go/aws/endpoints" ) type snsSqsMetadata struct { @@ -84,14 +82,6 @@ func (s *snsSqs) getSnsSqsMetadata(meta pubsub.Metadata) (*snsSqsMetadata, error return nil, err } - if md.Region != "" { - if partition, ok := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), md.Region); ok { - md.internalPartition = partition.ID() - } else { - md.internalPartition = "aws" - } - } - if md.SqsQueueName == "" { return nil, errors.New("consumerID must be set") } diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index e84dbeb80a..548af153da 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -33,12 +33,12 @@ import ( "github.com/dapr/kit/retry" - gonanoid "github.com/matoous/go-nanoid/v2" - + "github.com/aws/aws-sdk-go/aws/endpoints" awsAuth "github.com/dapr/components-contrib/common/authentication/aws" "github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/pubsub" "github.com/dapr/kit/logger" + gonanoid "github.com/matoous/go-nanoid/v2" ) type snsSqs struct { @@ -162,6 +162,14 @@ func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error { s.authProvider = provider } + if s.authProvider.SnsSqs().Region() != "" { + if partition, ok := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), s.authProvider.SnsSqs().Region()); ok { + m.internalPartition = partition.ID() + } else { + m.internalPartition = "aws" + } + } + s.opsTimeout = time.Duration(m.AssetsManagementTimeoutSeconds * float64(time.Second)) err = s.setAwsAccountIDIfNotProvided(ctx) @@ -201,7 +209,7 @@ func (s *snsSqs) setAwsAccountIDIfNotProvided(parentCtx context.Context) error { } func (s *snsSqs) buildARN(serviceName, entityName string) string { - return fmt.Sprintf("arn:%s:%s:%s:%s:%s", s.metadata.internalPartition, serviceName, s.metadata.Region, s.metadata.AccountID, entityName) + return fmt.Sprintf("arn:%s:%s:%s:%s:%s", s.metadata.internalPartition, serviceName, s.authProvider.SnsSqs().Region(), s.metadata.AccountID, entityName) } func (s *snsSqs) createTopic(parentCtx context.Context, topic string) (string, error) { @@ -257,9 +265,8 @@ func (s *snsSqs) getOrCreateTopic(ctx context.Context, topic string) (topicArn s } // creating queues is idempotent, the names serve as unique keys among a given region. - s.logger.Debugf("No SNS topic ARN found for topic: %s. creating SNS with (sanitized) topic: %s", topic, sanitizedTopic) - if !s.metadata.DisableEntityManagement { + s.logger.Debugf("No SNS topic ARN found for topic: %s. creating SNS with (sanitized) topic: %s", topic, sanitizedTopic) topicArn, err = s.createTopic(ctx, sanitizedTopic) if err != nil { err = fmt.Errorf("error creating new (sanitized) topic '%s': %w", topic, err) @@ -267,10 +274,14 @@ func (s *snsSqs) getOrCreateTopic(ctx context.Context, topic string) (topicArn s return topicArn, sanitizedTopic, err } } else { + s.logger.Debugf("No SNS topic ARN found for topic: %s. Checking AWS SNS for if (sanitized) topic exists: %s", topic, sanitizedTopic) topicArn, err = s.getTopicArn(ctx, sanitizedTopic) if err != nil { + var awsErr awserr.Error + if errors.As(err, &awsErr) && awsErr.Code() == sns.ErrCodeNotFoundException { + return topicArn, sanitizedTopic, errors.New("topic not found") + } err = fmt.Errorf("error fetching info for (sanitized) topic: %s. wrapped error is: %w", topic, err) - return topicArn, sanitizedTopic, err } } @@ -855,7 +866,9 @@ func (s *snsSqs) Publish(ctx context.Context, req *pubsub.PublishRequest) error topicArn, _, err := s.getOrCreateTopic(ctx, req.Topic) if err != nil { - s.logger.Errorf("error getting topic ARN for %s: %v", req.Topic, err) + wrappedErr := fmt.Errorf("error getting topic ARN for %s: %v", req.Topic, err) + s.logger.Error(wrappedErr) + return wrappedErr } message := string(req.Data) From 8531353d7561e0f7e4287a1d35b04e95d8517b02 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Tue, 10 Dec 2024 10:07:00 -0600 Subject: [PATCH 2/4] style: make linter happy Signed-off-by: Samantha Coyle --- pubsub/aws/snssqs/snssqs.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index 548af153da..021fd0249a 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -34,11 +34,12 @@ import ( "github.com/dapr/kit/retry" "github.com/aws/aws-sdk-go/aws/endpoints" + gonanoid "github.com/matoous/go-nanoid/v2" + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" "github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/pubsub" "github.com/dapr/kit/logger" - gonanoid "github.com/matoous/go-nanoid/v2" ) type snsSqs struct { From 6aae517089be72e1df78c229412f52536c5a5847 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Tue, 10 Dec 2024 11:28:18 -0600 Subject: [PATCH 3/4] fix(tests): update client for tests Signed-off-by: Samantha Coyle --- common/authentication/aws/client.go | 15 +++++++++++---- common/authentication/aws/static.go | 10 +++++----- common/authentication/aws/x509.go | 10 +++++----- pubsub/aws/snssqs/metadata.go | 5 +++++ pubsub/aws/snssqs/snssqs_test.go | 13 ++++++++++++- 5 files changed, 38 insertions(+), 15 deletions(-) diff --git a/common/authentication/aws/client.go b/common/authentication/aws/client.go index 5cb4bf1fd8..c7829d7966 100644 --- a/common/authentication/aws/client.go +++ b/common/authentication/aws/client.go @@ -51,7 +51,7 @@ type Clients struct { Dynamo *DynamoDBClients sns *SnsClients sqs *SqsClients - snssqs *SnsSqsClients + Snssqs *SnsSqsClients Secret *SecretManagerClients ParameterStore *ParameterStoreClients kinesis *KinesisClients @@ -75,8 +75,8 @@ func (c *Clients) refresh(session *session.Session) error { c.sns.New(session) case c.sqs != nil: c.sqs.New(session) - case c.snssqs != nil: - c.snssqs.New(session) + case c.Snssqs != nil: + c.Snssqs.New(session) case c.Secret != nil: c.Secret.New(session) case c.ParameterStore != nil: @@ -170,11 +170,18 @@ func (c *SnsClients) New(session *session.Session) { c.Sns = sns.New(session, session.Config) } +func (c *SnsSqsClients) SetRegion(region string) { + c.region = region +} + func (c *SnsSqsClients) New(session *session.Session) { c.Sns = sns.New(session, session.Config) c.Sqs = sqs.New(session, session.Config) c.Sts = sts.New(session, session.Config) - c.region = *session.Config.Region + // the empty string check is added to allow for this to be injected by tests + if c.region != "" { + c.SetRegion(*session.Config.Region) + } } func (c *SqsClients) New(session *session.Session) { diff --git a/common/authentication/aws/static.go b/common/authentication/aws/static.go index 5fa887c474..61416e7e5d 100644 --- a/common/authentication/aws/static.go +++ b/common/authentication/aws/static.go @@ -168,14 +168,14 @@ func (a *StaticAuth) SnsSqs() *SnsSqsClients { a.mu.Lock() defer a.mu.Unlock() - if a.clients.snssqs != nil { - return a.clients.snssqs + if a.clients.Snssqs != nil { + return a.clients.Snssqs } clients := SnsSqsClients{} - a.clients.snssqs = &clients - a.clients.snssqs.New(a.session) - return a.clients.snssqs + a.clients.Snssqs = &clients + a.clients.Snssqs.New(a.session) + return a.clients.Snssqs } func (a *StaticAuth) SecretManager() *SecretManagerClients { diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go index 6556ece74c..299605b24c 100644 --- a/common/authentication/aws/x509.go +++ b/common/authentication/aws/x509.go @@ -242,14 +242,14 @@ func (a *x509) SnsSqs() *SnsSqsClients { a.mu.Lock() defer a.mu.Unlock() - if a.clients.snssqs != nil { - return a.clients.snssqs + if a.clients.Snssqs != nil { + return a.clients.Snssqs } clients := SnsSqsClients{} - a.clients.snssqs = &clients - a.clients.snssqs.New(a.session) - return a.clients.snssqs + a.clients.Snssqs = &clients + a.clients.Snssqs.New(a.session) + return a.clients.Snssqs } func (a *x509) SecretManager() *SecretManagerClients { diff --git a/pubsub/aws/snssqs/metadata.go b/pubsub/aws/snssqs/metadata.go index ff1d0402ad..7064665767 100644 --- a/pubsub/aws/snssqs/metadata.go +++ b/pubsub/aws/snssqs/metadata.go @@ -82,6 +82,11 @@ func (s *snsSqs) getSnsSqsMetadata(meta pubsub.Metadata) (*snsSqsMetadata, error return nil, err } + // set an initial default that we will override if an aws region is found after we init the snssqs via the aws auth provider in Init() + if md.internalPartition == "" { + md.internalPartition = "aws" + } + if md.SqsQueueName == "" { return nil, errors.New("consumerID must be set") } diff --git a/pubsub/aws/snssqs/snssqs_test.go b/pubsub/aws/snssqs/snssqs_test.go index 82f58bcbff..0f1c9dd73f 100644 --- a/pubsub/aws/snssqs/snssqs_test.go +++ b/pubsub/aws/snssqs/snssqs_test.go @@ -19,6 +19,8 @@ import ( "github.com/stretchr/testify/require" + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" + "github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/pubsub" "github.com/dapr/kit/logger" @@ -445,8 +447,17 @@ func Test_buildARN_DefaultPartition(t *testing.T) { r := require.New(t) l := logger.NewLogger("SnsSqs unit test") l.SetOutputLevel(logger.DebugLevel) + mockAuthProvider := &awsAuth.StaticAuth{} + mockedSnssqs := &awsAuth.SnsSqsClients{} + mockedSnssqs.SetRegion("r") + mockedClients := awsAuth.Clients{ + Snssqs: mockedSnssqs, + } + mockAuthProvider.WithMockClients(&mockedClients) + ps := snsSqs{ - logger: l, + logger: l, + authProvider: mockAuthProvider, } md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ From 63125970e5afd311a482dfb011c85f7f70a2b7f6 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Tue, 10 Dec 2024 14:01:19 -0600 Subject: [PATCH 4/4] fix(tests): fix last two Signed-off-by: Samantha Coyle --- pubsub/aws/snssqs/metadata.go | 5 ----- pubsub/aws/snssqs/snssqs.go | 6 ++---- pubsub/aws/snssqs/snssqs_test.go | 31 +++++++++++++++++++++++++++++-- 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/pubsub/aws/snssqs/metadata.go b/pubsub/aws/snssqs/metadata.go index 7064665767..ff1d0402ad 100644 --- a/pubsub/aws/snssqs/metadata.go +++ b/pubsub/aws/snssqs/metadata.go @@ -82,11 +82,6 @@ func (s *snsSqs) getSnsSqsMetadata(meta pubsub.Metadata) (*snsSqsMetadata, error return nil, err } - // set an initial default that we will override if an aws region is found after we init the snssqs via the aws auth provider in Init() - if md.internalPartition == "" { - md.internalPartition = "aws" - } - if md.SqsQueueName == "" { return nil, errors.New("consumerID must be set") } diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index 021fd0249a..7578f46160 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -141,7 +141,6 @@ func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error { if err != nil { return err } - s.metadata = m if s.authProvider == nil { @@ -165,12 +164,11 @@ func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error { if s.authProvider.SnsSqs().Region() != "" { if partition, ok := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), s.authProvider.SnsSqs().Region()); ok { - m.internalPartition = partition.ID() + s.metadata.internalPartition = partition.ID() } else { - m.internalPartition = "aws" + s.metadata.internalPartition = "aws" } } - s.opsTimeout = time.Duration(m.AssetsManagementTimeoutSeconds * float64(time.Second)) err = s.setAwsAccountIDIfNotProvided(ctx) diff --git a/pubsub/aws/snssqs/snssqs_test.go b/pubsub/aws/snssqs/snssqs_test.go index 0f1c9dd73f..2f97320cb4 100644 --- a/pubsub/aws/snssqs/snssqs_test.go +++ b/pubsub/aws/snssqs/snssqs_test.go @@ -470,6 +470,9 @@ func Test_buildARN_DefaultPartition(t *testing.T) { md.AccountID = "123456789012" ps.metadata = md + // This is now set in the Init func, so we must set it in test to have the value to build the arn with + ps.metadata.internalPartition = "aws" + arn := ps.buildARN("sns", "myTopic") r.Equal("arn:aws:sns:r:123456789012:myTopic", arn) } @@ -479,8 +482,17 @@ func Test_buildARN_StandardPartition(t *testing.T) { r := require.New(t) l := logger.NewLogger("SnsSqs unit test") l.SetOutputLevel(logger.DebugLevel) + mockAuthProvider := &awsAuth.StaticAuth{} + mockedSnssqs := &awsAuth.SnsSqsClients{} + mockedSnssqs.SetRegion("us-west-2") + mockedClients := awsAuth.Clients{ + Snssqs: mockedSnssqs, + } + mockAuthProvider.WithMockClients(&mockedClients) + ps := snsSqs{ - logger: l, + logger: l, + authProvider: mockAuthProvider, } md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ @@ -493,6 +505,9 @@ func Test_buildARN_StandardPartition(t *testing.T) { md.AccountID = "123456789012" ps.metadata = md + // This is now set in the Init func, so we must set it in test to have the value to build the arn with + ps.metadata.internalPartition = "aws" + arn := ps.buildARN("sns", "myTopic") r.Equal("arn:aws:sns:us-west-2:123456789012:myTopic", arn) } @@ -502,8 +517,17 @@ func Test_buildARN_NonStandardPartition(t *testing.T) { r := require.New(t) l := logger.NewLogger("SnsSqs unit test") l.SetOutputLevel(logger.DebugLevel) + mockAuthProvider := &awsAuth.StaticAuth{} + mockedSnssqs := &awsAuth.SnsSqsClients{} + mockedSnssqs.SetRegion("cn-northwest-1") + mockedClients := awsAuth.Clients{ + Snssqs: mockedSnssqs, + } + mockAuthProvider.WithMockClients(&mockedClients) + ps := snsSqs{ - logger: l, + logger: l, + authProvider: mockAuthProvider, } md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ @@ -516,6 +540,9 @@ func Test_buildARN_NonStandardPartition(t *testing.T) { md.AccountID = "123456789012" ps.metadata = md + // This is now set in the Init func, so we must set it in test to have the value to build the arn with + ps.metadata.internalPartition = "aws-cn" + arn := ps.buildARN("sns", "myTopic") r.Equal("arn:aws-cn:sns:cn-northwest-1:123456789012:myTopic", arn) }