Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(snssqs): if topic dne + we dont create, then we skip the actual err #3630

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions common/authentication/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type Clients struct {
Dynamo *DynamoDBClients
sns *SnsClients
sqs *SqsClients
snssqs *SnsSqsClients
Snssqs *SnsSqsClients
Secret *SecretManagerClients
ParameterStore *ParameterStoreClients
kinesis *KinesisClients
Expand All @@ -75,8 +75,8 @@ func (c *Clients) refresh(session *session.Session) error {
c.sns.New(session)
case c.sqs != nil:
c.sqs.New(session)
case c.snssqs != nil:
c.snssqs.New(session)
case c.Snssqs != nil:
c.Snssqs.New(session)
case c.Secret != nil:
c.Secret.New(session)
case c.ParameterStore != nil:
Expand Down Expand Up @@ -111,6 +111,12 @@ type SnsSqsClients struct {
Sns *sns.SNS
Sqs *sqs.SQS
Sts *sts.STS

region string
}

func (c *SnsSqsClients) Region() string {
return c.region
}

type SnsClients struct {
Expand Down Expand Up @@ -164,10 +170,18 @@ func (c *SnsClients) New(session *session.Session) {
c.Sns = sns.New(session, session.Config)
}

func (c *SnsSqsClients) SetRegion(region string) {
c.region = region
}

func (c *SnsSqsClients) New(session *session.Session) {
c.Sns = sns.New(session, session.Config)
c.Sqs = sqs.New(session, session.Config)
c.Sts = sts.New(session, session.Config)
// the empty string check is added to allow for this to be injected by tests
if c.region != "" {
c.SetRegion(*session.Config.Region)
}
}

func (c *SqsClients) New(session *session.Session) {
Expand Down
13 changes: 8 additions & 5 deletions common/authentication/aws/static.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,14 @@ func (a *StaticAuth) SnsSqs() *SnsSqsClients {
a.mu.Lock()
defer a.mu.Unlock()

if a.clients.snssqs != nil {
return a.clients.snssqs
if a.clients.Snssqs != nil {
return a.clients.Snssqs
}

clients := SnsSqsClients{}
a.clients.snssqs = &clients
a.clients.snssqs.New(a.session)
return a.clients.snssqs
a.clients.Snssqs = &clients
a.clients.Snssqs.New(a.session)
return a.clients.Snssqs
}

func (a *StaticAuth) SecretManager() *SecretManagerClients {
Expand Down Expand Up @@ -250,6 +250,9 @@ func (a *StaticAuth) UpdatePostgres(ctx context.Context, poolConfig *pgxpool.Con
if err != nil {
return fmt.Errorf("failed to get database token: %w", err)
}
if pwd == "" {
return errors.New("failed to get a valid password for the database token for AWS IAM authentication")
}
pgConfig.Password = pwd
poolConfig.ConnConfig.Password = pwd

Expand Down
10 changes: 5 additions & 5 deletions common/authentication/aws/x509.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,14 +242,14 @@ func (a *x509) SnsSqs() *SnsSqsClients {
a.mu.Lock()
defer a.mu.Unlock()

if a.clients.snssqs != nil {
return a.clients.snssqs
if a.clients.Snssqs != nil {
return a.clients.Snssqs
}

clients := SnsSqsClients{}
a.clients.snssqs = &clients
a.clients.snssqs.New(a.session)
return a.clients.snssqs
a.clients.Snssqs = &clients
a.clients.Snssqs.New(a.session)
return a.clients.Snssqs
}

func (a *x509) SecretManager() *SecretManagerClients {
Expand Down
10 changes: 0 additions & 10 deletions pubsub/aws/snssqs/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (

"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/metadata"

"github.com/aws/aws-sdk-go/aws/endpoints"
)

type snsSqsMetadata struct {
Expand Down Expand Up @@ -84,14 +82,6 @@ func (s *snsSqs) getSnsSqsMetadata(meta pubsub.Metadata) (*snsSqsMetadata, error
return nil, err
}

if md.Region != "" {
if partition, ok := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), md.Region); ok {
md.internalPartition = partition.ID()
} else {
md.internalPartition = "aws"
}
}

if md.SqsQueueName == "" {
return nil, errors.New("consumerID must be set")
}
Expand Down
24 changes: 18 additions & 6 deletions pubsub/aws/snssqs/snssqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (

"github.com/dapr/kit/retry"

"github.com/aws/aws-sdk-go/aws/endpoints"
gonanoid "github.com/matoous/go-nanoid/v2"

awsAuth "github.com/dapr/components-contrib/common/authentication/aws"
Expand Down Expand Up @@ -140,7 +141,6 @@ func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error {
if err != nil {
return err
}

s.metadata = m

if s.authProvider == nil {
Expand All @@ -162,6 +162,13 @@ func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error {
s.authProvider = provider
}

if s.authProvider.SnsSqs().Region() != "" {
if partition, ok := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), s.authProvider.SnsSqs().Region()); ok {
s.metadata.internalPartition = partition.ID()
} else {
s.metadata.internalPartition = "aws"
}
}
s.opsTimeout = time.Duration(m.AssetsManagementTimeoutSeconds * float64(time.Second))

err = s.setAwsAccountIDIfNotProvided(ctx)
Expand Down Expand Up @@ -201,7 +208,7 @@ func (s *snsSqs) setAwsAccountIDIfNotProvided(parentCtx context.Context) error {
}

func (s *snsSqs) buildARN(serviceName, entityName string) string {
return fmt.Sprintf("arn:%s:%s:%s:%s:%s", s.metadata.internalPartition, serviceName, s.metadata.Region, s.metadata.AccountID, entityName)
return fmt.Sprintf("arn:%s:%s:%s:%s:%s", s.metadata.internalPartition, serviceName, s.authProvider.SnsSqs().Region(), s.metadata.AccountID, entityName)
}

func (s *snsSqs) createTopic(parentCtx context.Context, topic string) (string, error) {
Expand Down Expand Up @@ -257,20 +264,23 @@ func (s *snsSqs) getOrCreateTopic(ctx context.Context, topic string) (topicArn s
}

// creating queues is idempotent, the names serve as unique keys among a given region.
s.logger.Debugf("No SNS topic ARN found for topic: %s. creating SNS with (sanitized) topic: %s", topic, sanitizedTopic)

if !s.metadata.DisableEntityManagement {
s.logger.Debugf("No SNS topic ARN found for topic: %s. creating SNS with (sanitized) topic: %s", topic, sanitizedTopic)
topicArn, err = s.createTopic(ctx, sanitizedTopic)
if err != nil {
err = fmt.Errorf("error creating new (sanitized) topic '%s': %w", topic, err)

return topicArn, sanitizedTopic, err
}
} else {
s.logger.Debugf("No SNS topic ARN found for topic: %s. Checking AWS SNS for if (sanitized) topic exists: %s", topic, sanitizedTopic)
topicArn, err = s.getTopicArn(ctx, sanitizedTopic)
if err != nil {
var awsErr awserr.Error
if errors.As(err, &awsErr) && awsErr.Code() == sns.ErrCodeNotFoundException {
return topicArn, sanitizedTopic, errors.New("topic not found")
}
err = fmt.Errorf("error fetching info for (sanitized) topic: %s. wrapped error is: %w", topic, err)

return topicArn, sanitizedTopic, err
}
}
Expand Down Expand Up @@ -855,7 +865,9 @@ func (s *snsSqs) Publish(ctx context.Context, req *pubsub.PublishRequest) error

topicArn, _, err := s.getOrCreateTopic(ctx, req.Topic)
if err != nil {
s.logger.Errorf("error getting topic ARN for %s: %v", req.Topic, err)
wrappedErr := fmt.Errorf("error getting topic ARN for %s: %v", req.Topic, err)
s.logger.Error(wrappedErr)
return wrappedErr
}

message := string(req.Data)
Expand Down
44 changes: 41 additions & 3 deletions pubsub/aws/snssqs/snssqs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (

"github.com/stretchr/testify/require"

awsAuth "github.com/dapr/components-contrib/common/authentication/aws"

"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
Expand Down Expand Up @@ -445,8 +447,17 @@ func Test_buildARN_DefaultPartition(t *testing.T) {
r := require.New(t)
l := logger.NewLogger("SnsSqs unit test")
l.SetOutputLevel(logger.DebugLevel)
mockAuthProvider := &awsAuth.StaticAuth{}
mockedSnssqs := &awsAuth.SnsSqsClients{}
mockedSnssqs.SetRegion("r")
mockedClients := awsAuth.Clients{
Snssqs: mockedSnssqs,
}
mockAuthProvider.WithMockClients(&mockedClients)

ps := snsSqs{
logger: l,
logger: l,
authProvider: mockAuthProvider,
}

md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{
Expand All @@ -459,6 +470,9 @@ func Test_buildARN_DefaultPartition(t *testing.T) {
md.AccountID = "123456789012"
ps.metadata = md

// This is now set in the Init func, so we must set it in test to have the value to build the arn with
ps.metadata.internalPartition = "aws"

arn := ps.buildARN("sns", "myTopic")
r.Equal("arn:aws:sns:r:123456789012:myTopic", arn)
}
Expand All @@ -468,8 +482,17 @@ func Test_buildARN_StandardPartition(t *testing.T) {
r := require.New(t)
l := logger.NewLogger("SnsSqs unit test")
l.SetOutputLevel(logger.DebugLevel)
mockAuthProvider := &awsAuth.StaticAuth{}
mockedSnssqs := &awsAuth.SnsSqsClients{}
mockedSnssqs.SetRegion("us-west-2")
mockedClients := awsAuth.Clients{
Snssqs: mockedSnssqs,
}
mockAuthProvider.WithMockClients(&mockedClients)

ps := snsSqs{
logger: l,
logger: l,
authProvider: mockAuthProvider,
}

md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{
Expand All @@ -482,6 +505,9 @@ func Test_buildARN_StandardPartition(t *testing.T) {
md.AccountID = "123456789012"
ps.metadata = md

// This is now set in the Init func, so we must set it in test to have the value to build the arn with
ps.metadata.internalPartition = "aws"

arn := ps.buildARN("sns", "myTopic")
r.Equal("arn:aws:sns:us-west-2:123456789012:myTopic", arn)
}
Expand All @@ -491,8 +517,17 @@ func Test_buildARN_NonStandardPartition(t *testing.T) {
r := require.New(t)
l := logger.NewLogger("SnsSqs unit test")
l.SetOutputLevel(logger.DebugLevel)
mockAuthProvider := &awsAuth.StaticAuth{}
mockedSnssqs := &awsAuth.SnsSqsClients{}
mockedSnssqs.SetRegion("cn-northwest-1")
mockedClients := awsAuth.Clients{
Snssqs: mockedSnssqs,
}
mockAuthProvider.WithMockClients(&mockedClients)

ps := snsSqs{
logger: l,
logger: l,
authProvider: mockAuthProvider,
}

md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{
Expand All @@ -505,6 +540,9 @@ func Test_buildARN_NonStandardPartition(t *testing.T) {
md.AccountID = "123456789012"
ps.metadata = md

// This is now set in the Init func, so we must set it in test to have the value to build the arn with
ps.metadata.internalPartition = "aws-cn"

arn := ps.buildARN("sns", "myTopic")
r.Equal("arn:aws-cn:sns:cn-northwest-1:123456789012:myTopic", arn)
}
Loading