Skip to content

Commit

Permalink
fix: allow for mocked clients without exported field
Browse files Browse the repository at this point in the history
Signed-off-by: Samantha Coyle <[email protected]>
  • Loading branch information
sicoyle committed Nov 14, 2024
1 parent a300a8c commit 0b80a39
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 279 deletions.
13 changes: 13 additions & 0 deletions common/authentication/aws/client.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package aws

import (
Expand Down
79 changes: 79 additions & 0 deletions common/authentication/aws/client_fake.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package aws

import (
"context"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface"
"github.com/aws/aws-sdk-go/service/secretsmanager"
"github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/aws/aws-sdk-go/service/ssm/ssmiface"
)

type MockParameterStore struct {
GetParameterFn func(context.Context, *ssm.GetParameterInput, ...request.Option) (*ssm.GetParameterOutput, error)
DescribeParametersFn func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error)
ssmiface.SSMAPI
}

func (m *MockParameterStore) GetParameterWithContext(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) {
return m.GetParameterFn(ctx, input, option...)
}

func (m *MockParameterStore) DescribeParametersWithContext(ctx context.Context, input *ssm.DescribeParametersInput, option ...request.Option) (*ssm.DescribeParametersOutput, error) {
return m.DescribeParametersFn(ctx, input, option...)
}

type MockSecretManager struct {
GetSecretValueFn func(context.Context, *secretsmanager.GetSecretValueInput, ...request.Option) (*secretsmanager.GetSecretValueOutput, error)
secretsmanageriface.SecretsManagerAPI
}

func (m *MockSecretManager) GetSecretValueWithContext(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) {
return m.GetSecretValueFn(ctx, input, option...)
}

type MockDynamoDB struct {
GetItemWithContextFn func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error)
PutItemWithContextFn func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error)
DeleteItemWithContextFn func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error)
BatchWriteItemWithContextFn func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error)
TransactWriteItemsWithContextFn func(aws.Context, *dynamodb.TransactWriteItemsInput, ...request.Option) (*dynamodb.TransactWriteItemsOutput, error)
dynamodbiface.DynamoDBAPI
}

func (m *MockDynamoDB) GetItemWithContext(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) {
return m.GetItemWithContextFn(ctx, input, op...)
}

func (m *MockDynamoDB) PutItemWithContext(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error) {
return m.PutItemWithContextFn(ctx, input, op...)
}

func (m *MockDynamoDB) DeleteItemWithContext(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error) {
return m.DeleteItemWithContextFn(ctx, input, op...)
}

func (m *MockDynamoDB) BatchWriteItemWithContext(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error) {
return m.BatchWriteItemWithContextFn(ctx, input, op...)
}

func (m *MockDynamoDB) TransactWriteItemsWithContext(ctx context.Context, input *dynamodb.TransactWriteItemsInput, op ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) {
return m.TransactWriteItemsWithContextFn(ctx, input, op...)
}
99 changes: 52 additions & 47 deletions common/authentication/aws/static.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type StaticAuth struct {

session *session.Session
cfg *aws.Config
Clients *Clients // exported to mock clients in unit tests
clients *Clients
}

func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth, error) {
Expand All @@ -60,7 +60,7 @@ func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth
}
return GetConfig(opts)
}(),
Clients: newClients(),
clients: newClients(),
}

initialSession, err := auth.getTokenClient()
Expand All @@ -73,132 +73,137 @@ func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth
return auth, nil
}

// This is to be used only for test purposes to inject mocked clients
func (a *StaticAuth) WithMockClients(clients *Clients) {
a.clients = clients
}

func (a *StaticAuth) S3() *S3Clients {
a.mu.Lock()
defer a.mu.Unlock()

if a.Clients.s3 != nil {
return a.Clients.s3
if a.clients.s3 != nil {
return a.clients.s3
}

s3Clients := S3Clients{}
a.Clients.s3 = &s3Clients
a.Clients.s3.New(a.session)
return a.Clients.s3
a.clients.s3 = &s3Clients
a.clients.s3.New(a.session)
return a.clients.s3
}

func (a *StaticAuth) DynamoDB() *DynamoDBClients {
a.mu.Lock()
defer a.mu.Unlock()

if a.Clients.Dynamo != nil {
return a.Clients.Dynamo
if a.clients.Dynamo != nil {
return a.clients.Dynamo
}

clients := DynamoDBClients{}
a.Clients.Dynamo = &clients
a.Clients.Dynamo.New(a.session)
a.clients.Dynamo = &clients
a.clients.Dynamo.New(a.session)

return a.Clients.Dynamo
return a.clients.Dynamo
}

func (a *StaticAuth) Sqs() *SqsClients {
a.mu.Lock()
defer a.mu.Unlock()

if a.Clients.sqs != nil {
return a.Clients.sqs
if a.clients.sqs != nil {
return a.clients.sqs
}

clients := SqsClients{}
a.Clients.sqs = &clients
a.Clients.sqs.New(a.session)
a.clients.sqs = &clients
a.clients.sqs.New(a.session)

return a.Clients.sqs
return a.clients.sqs
}

func (a *StaticAuth) Sns() *SnsClients {
a.mu.Lock()
defer a.mu.Unlock()

if a.Clients.sns != nil {
return a.Clients.sns
if a.clients.sns != nil {
return a.clients.sns
}

clients := SnsClients{}
a.Clients.sns = &clients
a.Clients.sns.New(a.session)
return a.Clients.sns
a.clients.sns = &clients
a.clients.sns.New(a.session)
return a.clients.sns
}

func (a *StaticAuth) SnsSqs() *SnsSqsClients {
a.mu.Lock()
defer a.mu.Unlock()

if a.Clients.snssqs != nil {
return a.Clients.snssqs
if a.clients.snssqs != nil {
return a.clients.snssqs
}

clients := SnsSqsClients{}
a.Clients.snssqs = &clients
a.Clients.snssqs.New(a.session)
return a.Clients.snssqs
a.clients.snssqs = &clients
a.clients.snssqs.New(a.session)
return a.clients.snssqs
}

func (a *StaticAuth) SecretManager() *SecretManagerClients {
a.mu.Lock()
defer a.mu.Unlock()

if a.Clients.Secret != nil {
return a.Clients.Secret
if a.clients.Secret != nil {
return a.clients.Secret
}

clients := SecretManagerClients{}
a.Clients.Secret = &clients
a.Clients.Secret.New(a.session)
return a.Clients.Secret
a.clients.Secret = &clients
a.clients.Secret.New(a.session)
return a.clients.Secret
}

func (a *StaticAuth) ParameterStore() *ParameterStoreClients {
a.mu.Lock()
defer a.mu.Unlock()

if a.Clients.ParameterStore != nil {
return a.Clients.ParameterStore
if a.clients.ParameterStore != nil {
return a.clients.ParameterStore
}

clients := ParameterStoreClients{}
a.Clients.ParameterStore = &clients
a.Clients.ParameterStore.New(a.session)
return a.Clients.ParameterStore
a.clients.ParameterStore = &clients
a.clients.ParameterStore.New(a.session)
return a.clients.ParameterStore
}

func (a *StaticAuth) Kinesis() *KinesisClients {
a.mu.Lock()
defer a.mu.Unlock()

if a.Clients.kinesis != nil {
return a.Clients.kinesis
if a.clients.kinesis != nil {
return a.clients.kinesis
}

clients := KinesisClients{}
a.Clients.kinesis = &clients
a.Clients.kinesis.New(a.session)
return a.Clients.kinesis
a.clients.kinesis = &clients
a.clients.kinesis.New(a.session)
return a.clients.kinesis
}

func (a *StaticAuth) Ses() *SesClients {
a.mu.Lock()
defer a.mu.Unlock()

if a.Clients.ses != nil {
return a.Clients.ses
if a.clients.ses != nil {
return a.clients.ses
}

clients := SesClients{}
a.Clients.ses = &clients
a.Clients.ses.New(a.session)
return a.Clients.ses
a.clients.ses = &clients
a.clients.ses.New(a.session)
return a.clients.ses
}

func (a *StaticAuth) getTokenClient() (*session.Session, error) {
Expand Down
Loading

0 comments on commit 0b80a39

Please sign in to comment.