Skip to content

Commit

Permalink
fix(tests): update client for tests
Browse files Browse the repository at this point in the history
Signed-off-by: Samantha Coyle <[email protected]>
  • Loading branch information
sicoyle committed Dec 10, 2024
1 parent 8531353 commit 6aae517
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 15 deletions.
15 changes: 11 additions & 4 deletions common/authentication/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type Clients struct {
Dynamo *DynamoDBClients
sns *SnsClients
sqs *SqsClients
snssqs *SnsSqsClients
Snssqs *SnsSqsClients
Secret *SecretManagerClients
ParameterStore *ParameterStoreClients
kinesis *KinesisClients
Expand All @@ -75,8 +75,8 @@ func (c *Clients) refresh(session *session.Session) error {
c.sns.New(session)
case c.sqs != nil:
c.sqs.New(session)
case c.snssqs != nil:
c.snssqs.New(session)
case c.Snssqs != nil:
c.Snssqs.New(session)
case c.Secret != nil:
c.Secret.New(session)
case c.ParameterStore != nil:
Expand Down Expand Up @@ -170,11 +170,18 @@ func (c *SnsClients) New(session *session.Session) {
c.Sns = sns.New(session, session.Config)
}

func (c *SnsSqsClients) SetRegion(region string) {
c.region = region
}

func (c *SnsSqsClients) New(session *session.Session) {
c.Sns = sns.New(session, session.Config)
c.Sqs = sqs.New(session, session.Config)
c.Sts = sts.New(session, session.Config)
c.region = *session.Config.Region
// the empty string check is added to allow for this to be injected by tests
if c.region != "" {
c.SetRegion(*session.Config.Region)
}
}

func (c *SqsClients) New(session *session.Session) {
Expand Down
10 changes: 5 additions & 5 deletions common/authentication/aws/static.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,14 @@ func (a *StaticAuth) SnsSqs() *SnsSqsClients {
a.mu.Lock()
defer a.mu.Unlock()

if a.clients.snssqs != nil {
return a.clients.snssqs
if a.clients.Snssqs != nil {
return a.clients.Snssqs
}

clients := SnsSqsClients{}
a.clients.snssqs = &clients
a.clients.snssqs.New(a.session)
return a.clients.snssqs
a.clients.Snssqs = &clients
a.clients.Snssqs.New(a.session)
return a.clients.Snssqs
}

func (a *StaticAuth) SecretManager() *SecretManagerClients {
Expand Down
10 changes: 5 additions & 5 deletions common/authentication/aws/x509.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,14 +242,14 @@ func (a *x509) SnsSqs() *SnsSqsClients {
a.mu.Lock()
defer a.mu.Unlock()

if a.clients.snssqs != nil {
return a.clients.snssqs
if a.clients.Snssqs != nil {
return a.clients.Snssqs
}

clients := SnsSqsClients{}
a.clients.snssqs = &clients
a.clients.snssqs.New(a.session)
return a.clients.snssqs
a.clients.Snssqs = &clients
a.clients.Snssqs.New(a.session)
return a.clients.Snssqs
}

func (a *x509) SecretManager() *SecretManagerClients {
Expand Down
5 changes: 5 additions & 0 deletions pubsub/aws/snssqs/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ func (s *snsSqs) getSnsSqsMetadata(meta pubsub.Metadata) (*snsSqsMetadata, error
return nil, err
}

// set an initial default that we will override if an aws region is found after we init the snssqs via the aws auth provider in Init()
if md.internalPartition == "" {
md.internalPartition = "aws"
}

if md.SqsQueueName == "" {
return nil, errors.New("consumerID must be set")
}
Expand Down
13 changes: 12 additions & 1 deletion pubsub/aws/snssqs/snssqs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (

"github.com/stretchr/testify/require"

awsAuth "github.com/dapr/components-contrib/common/authentication/aws"

"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
Expand Down Expand Up @@ -445,8 +447,17 @@ func Test_buildARN_DefaultPartition(t *testing.T) {
r := require.New(t)
l := logger.NewLogger("SnsSqs unit test")
l.SetOutputLevel(logger.DebugLevel)
mockAuthProvider := &awsAuth.StaticAuth{}
mockedSnssqs := &awsAuth.SnsSqsClients{}
mockedSnssqs.SetRegion("r")
mockedClients := awsAuth.Clients{
Snssqs: mockedSnssqs,
}
mockAuthProvider.WithMockClients(&mockedClients)

ps := snsSqs{
logger: l,
logger: l,
authProvider: mockAuthProvider,
}

md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{
Expand Down

0 comments on commit 6aae517

Please sign in to comment.