Skip to content

Commit

Permalink
fix(snssqs): if topic dne and we do not create, then we skip the actu…
Browse files Browse the repository at this point in the history
…al err

Signed-off-by: Samantha Coyle <[email protected]>
  • Loading branch information
sicoyle committed Dec 10, 2024
1 parent 026ae76 commit 901420e
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 17 deletions.
7 changes: 7 additions & 0 deletions common/authentication/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ type SnsSqsClients struct {
Sns *sns.SNS
Sqs *sqs.SQS
Sts *sts.STS

region string
}

func (c *SnsSqsClients) Region() string {
return c.region
}

type SnsClients struct {
Expand Down Expand Up @@ -168,6 +174,7 @@ func (c *SnsSqsClients) New(session *session.Session) {
c.Sns = sns.New(session, session.Config)
c.Sqs = sqs.New(session, session.Config)
c.Sts = sts.New(session, session.Config)
c.region = *session.Config.Region
}

func (c *SqsClients) New(session *session.Session) {
Expand Down
3 changes: 3 additions & 0 deletions common/authentication/aws/static.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,9 @@ func (a *StaticAuth) UpdatePostgres(ctx context.Context, poolConfig *pgxpool.Con
if err != nil {
return fmt.Errorf("failed to get database token: %w", err)
}
if pwd == "" {
return errors.New("failed to get a valid password for the database token for AWS IAM authentication")
}
pgConfig.Password = pwd
poolConfig.ConnConfig.Password = pwd

Expand Down
10 changes: 0 additions & 10 deletions pubsub/aws/snssqs/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (

"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/metadata"

"github.com/aws/aws-sdk-go/aws/endpoints"
)

type snsSqsMetadata struct {
Expand Down Expand Up @@ -84,14 +82,6 @@ func (s *snsSqs) getSnsSqsMetadata(meta pubsub.Metadata) (*snsSqsMetadata, error
return nil, err
}

if md.Region != "" {
if partition, ok := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), md.Region); ok {
md.internalPartition = partition.ID()
} else {
md.internalPartition = "aws"
}
}

if md.SqsQueueName == "" {
return nil, errors.New("consumerID must be set")
}
Expand Down
27 changes: 20 additions & 7 deletions pubsub/aws/snssqs/snssqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ import (

"github.com/dapr/kit/retry"

gonanoid "github.com/matoous/go-nanoid/v2"

"github.com/aws/aws-sdk-go/aws/endpoints"
awsAuth "github.com/dapr/components-contrib/common/authentication/aws"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
gonanoid "github.com/matoous/go-nanoid/v2"
)

type snsSqs struct {
Expand Down Expand Up @@ -162,6 +162,14 @@ func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error {
s.authProvider = provider
}

if s.authProvider.SnsSqs().Region() != "" {
if partition, ok := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), s.authProvider.SnsSqs().Region()); ok {
m.internalPartition = partition.ID()
} else {
m.internalPartition = "aws"
}
}

s.opsTimeout = time.Duration(m.AssetsManagementTimeoutSeconds * float64(time.Second))

err = s.setAwsAccountIDIfNotProvided(ctx)
Expand Down Expand Up @@ -201,7 +209,7 @@ func (s *snsSqs) setAwsAccountIDIfNotProvided(parentCtx context.Context) error {
}

func (s *snsSqs) buildARN(serviceName, entityName string) string {
return fmt.Sprintf("arn:%s:%s:%s:%s:%s", s.metadata.internalPartition, serviceName, s.metadata.Region, s.metadata.AccountID, entityName)
return fmt.Sprintf("arn:%s:%s:%s:%s:%s", s.metadata.internalPartition, serviceName, s.authProvider.SnsSqs().Region(), s.metadata.AccountID, entityName)
}

func (s *snsSqs) createTopic(parentCtx context.Context, topic string) (string, error) {
Expand Down Expand Up @@ -257,20 +265,23 @@ func (s *snsSqs) getOrCreateTopic(ctx context.Context, topic string) (topicArn s
}

// creating queues is idempotent, the names serve as unique keys among a given region.
s.logger.Debugf("No SNS topic ARN found for topic: %s. creating SNS with (sanitized) topic: %s", topic, sanitizedTopic)

if !s.metadata.DisableEntityManagement {
s.logger.Debugf("No SNS topic ARN found for topic: %s. creating SNS with (sanitized) topic: %s", topic, sanitizedTopic)
topicArn, err = s.createTopic(ctx, sanitizedTopic)
if err != nil {
err = fmt.Errorf("error creating new (sanitized) topic '%s': %w", topic, err)

return topicArn, sanitizedTopic, err
}
} else {
s.logger.Debugf("No SNS topic ARN found for topic: %s. Checking AWS SNS for if (sanitized) topic exists: %s", topic, sanitizedTopic)
topicArn, err = s.getTopicArn(ctx, sanitizedTopic)
if err != nil {
var awsErr awserr.Error
if errors.As(err, &awsErr) && awsErr.Code() == sns.ErrCodeNotFoundException {
return topicArn, sanitizedTopic, errors.New("topic not found")
}
err = fmt.Errorf("error fetching info for (sanitized) topic: %s. wrapped error is: %w", topic, err)

return topicArn, sanitizedTopic, err
}
}
Expand Down Expand Up @@ -855,7 +866,9 @@ func (s *snsSqs) Publish(ctx context.Context, req *pubsub.PublishRequest) error

topicArn, _, err := s.getOrCreateTopic(ctx, req.Topic)
if err != nil {
s.logger.Errorf("error getting topic ARN for %s: %v", req.Topic, err)
wrappedErr := fmt.Errorf("error getting topic ARN for %s: %v", req.Topic, err)
s.logger.Error(wrappedErr)
return wrappedErr
}

message := string(req.Data)
Expand Down

0 comments on commit 901420e

Please sign in to comment.