Skip to content

Commit

Permalink
feat(iam auth): allow iam roles anywhere auth profile (#3591)
Browse files Browse the repository at this point in the history
Signed-off-by: Samantha Coyle <[email protected]>
Signed-off-by: Sam <[email protected]>
  • Loading branch information
sicoyle authored Nov 14, 2024
1 parent 2b924c4 commit a00a853
Show file tree
Hide file tree
Showing 31 changed files with 2,647 additions and 753 deletions.
20 changes: 19 additions & 1 deletion .build-tools/builtin-authentication-profiles.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,25 @@ aws:
type: string
- title: "AWS: Credentials from Environment Variables"
description: Use AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY from the environment

- title: "AWS: IAM Roles Anywhere"
description: Use X.509 certificates to establish trust between AWS and your AWS account and the Dapr cluster using AWS IAM Roles Anywhere.
metadata:
- name: trustAnchorArn
description: |
ARN of the AWS Trust Anchor in the AWS account granting trust to the Dapr Certificate Authority.
example: arn:aws:rolesanywhere:us-west-1:012345678910:trust-anchor/01234568-0123-0123-0123-012345678901
required: true
- name: trustProfileArn
description: |
ARN of the AWS IAM Profile in the trusting AWS account.
example: arn:aws:rolesanywhere:us-west-1:012345678910:profile/01234568-0123-0123-0123-012345678901
required: true
- name: assumeRoleArn
description: |
ARN of the AWS IAM role to assume in the trusting AWS account.
example: arn:aws:iam:012345678910:role/exampleIAMRoleName
required: true

azuread:
- title: "Azure AD: Managed identity"
description: Authenticate using Azure AD and a managed identity.
Expand Down
37 changes: 18 additions & 19 deletions bindings/aws/dynamodb/dynamodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ import (

// DynamoDB allows performing stateful operations on AWS DynamoDB.
type DynamoDB struct {
client *dynamodb.DynamoDB
table string
logger logger.Logger
authProvider awsAuth.Provider
table string
logger logger.Logger
}

type dynamoDBMetadata struct {
Expand All @@ -51,18 +51,27 @@ func NewDynamoDB(logger logger.Logger) bindings.OutputBinding {
}

// Init performs connection parsing for DynamoDB.
func (d *DynamoDB) Init(_ context.Context, metadata bindings.Metadata) error {
func (d *DynamoDB) Init(ctx context.Context, metadata bindings.Metadata) error {
meta, err := d.getDynamoDBMetadata(metadata)
if err != nil {
return err
}

client, err := d.getClient(meta)
opts := awsAuth.Options{
Logger: d.logger,
Properties: metadata.Properties,
Region: meta.Region,
Endpoint: meta.Endpoint,
AccessKey: meta.AccessKey,
SecretKey: meta.SecretKey,
SessionToken: meta.SessionToken,
}

provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts))
if err != nil {
return err
}

d.client = client
d.authProvider = provider
d.table = meta.Table

return nil
Expand All @@ -84,7 +93,7 @@ func (d *DynamoDB) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bi
return nil, err
}

_, err = d.client.PutItemWithContext(ctx, &dynamodb.PutItemInput{
_, err = d.authProvider.DynamoDB().DynamoDB.PutItemWithContext(ctx, &dynamodb.PutItemInput{
Item: item,
TableName: aws.String(d.table),
})
Expand All @@ -105,16 +114,6 @@ func (d *DynamoDB) getDynamoDBMetadata(spec bindings.Metadata) (*dynamoDBMetadat
return &meta, nil
}

func (d *DynamoDB) getClient(metadata *dynamoDBMetadata) (*dynamodb.DynamoDB, error) {
sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint)
if err != nil {
return nil, err
}
c := dynamodb.New(sess)

return c, nil
}

// GetComponentMetadata returns the metadata of the component.
func (d *DynamoDB) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
metadataStruct := dynamoDBMetadata{}
Expand All @@ -123,5 +122,5 @@ func (d *DynamoDB) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
}

func (d *DynamoDB) Close() error {
return nil
return d.authProvider.Close()
}
89 changes: 38 additions & 51 deletions bindings/aws/kinesis/kinesis.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
"github.com/vmware/vmware-go-kcl/clientlibrary/config"
"github.com/vmware/vmware-go-kcl/clientlibrary/interfaces"
"github.com/vmware/vmware-go-kcl/clientlibrary/worker"

Expand All @@ -40,15 +39,16 @@ import (

// AWSKinesis allows receiving and sending data to/from AWS Kinesis stream.
type AWSKinesis struct {
client *kinesis.Kinesis
metadata *kinesisMetadata
authProvider awsAuth.Provider
metadata *kinesisMetadata

worker *worker.Worker
workerConfig *config.KinesisClientLibConfiguration
worker *worker.Worker

streamARN *string
consumerARN *string
logger logger.Logger
streamName string
consumerName string
consumerARN *string
logger logger.Logger
consumerMode string

closed atomic.Bool
closeCh chan struct{}
Expand Down Expand Up @@ -112,30 +112,25 @@ func (a *AWSKinesis) Init(ctx context.Context, metadata bindings.Metadata) error
return fmt.Errorf("%s invalid \"mode\" field %s", "aws.kinesis", m.KinesisConsumerMode)
}

client, err := a.getClient(m)
if err != nil {
return err
}
a.consumerMode = m.KinesisConsumerMode
a.streamName = m.StreamName
a.consumerName = m.ConsumerName
a.metadata = m

streamName := aws.String(m.StreamName)
stream, err := client.DescribeStreamWithContext(ctx, &kinesis.DescribeStreamInput{
StreamName: streamName,
})
opts := awsAuth.Options{
Logger: a.logger,
Properties: metadata.Properties,
Region: m.Region,
AccessKey: m.AccessKey,
SecretKey: m.SecretKey,
SessionToken: "",
}
// extra configs needed per component type
provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts))
if err != nil {
return err
}

if m.KinesisConsumerMode == SharedThroughput {
kclConfig := config.NewKinesisClientLibConfigWithCredential(m.ConsumerName,
m.StreamName, m.Region, m.ConsumerName,
client.Config.Credentials)
a.workerConfig = kclConfig
}

a.streamARN = stream.StreamDescription.StreamARN
a.metadata = m
a.client = client

a.authProvider = provider
return nil
}

Expand All @@ -148,7 +143,7 @@ func (a *AWSKinesis) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*
if partitionKey == "" {
partitionKey = uuid.New().String()
}
_, err := a.client.PutRecordWithContext(ctx, &kinesis.PutRecordInput{
_, err := a.authProvider.Kinesis().Kinesis.PutRecordWithContext(ctx, &kinesis.PutRecordInput{
StreamName: &a.metadata.StreamName,
Data: req.Data,
PartitionKey: &partitionKey,
Expand All @@ -161,16 +156,15 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er
if a.closed.Load() {
return errors.New("binding is closed")
}

if a.metadata.KinesisConsumerMode == SharedThroughput {
a.worker = worker.NewWorker(a.recordProcessorFactory(ctx, handler), a.workerConfig)
a.worker = worker.NewWorker(a.recordProcessorFactory(ctx, handler), a.authProvider.Kinesis().WorkerCfg(ctx, a.streamName, a.consumerName, a.consumerMode))
err = a.worker.Start()
if err != nil {
return err
}
} else if a.metadata.KinesisConsumerMode == ExtendedFanout {
var stream *kinesis.DescribeStreamOutput
stream, err = a.client.DescribeStream(&kinesis.DescribeStreamInput{StreamName: &a.metadata.StreamName})
stream, err = a.authProvider.Kinesis().Kinesis.DescribeStream(&kinesis.DescribeStreamInput{StreamName: &a.metadata.StreamName})
if err != nil {
return err
}
Expand All @@ -180,6 +174,10 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er
}
}

stream, err := a.authProvider.Kinesis().Stream(ctx, a.streamName)
if err != nil {
return fmt.Errorf("failed to get kinesis stream arn: %v", err)
}
// Wait for context cancelation then stop
a.wg.Add(1)
go func() {
Expand All @@ -191,7 +189,7 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er
if a.metadata.KinesisConsumerMode == SharedThroughput {
a.worker.Shutdown()
} else if a.metadata.KinesisConsumerMode == ExtendedFanout {
a.deregisterConsumer(a.streamARN, a.consumerARN)
a.deregisterConsumer(ctx, stream, a.consumerARN)
}
}()

Expand Down Expand Up @@ -226,8 +224,7 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes
return
default:
}

sub, err := a.client.SubscribeToShardWithContext(ctx, &kinesis.SubscribeToShardInput{
sub, err := a.authProvider.Kinesis().Kinesis.SubscribeToShardWithContext(ctx, &kinesis.SubscribeToShardInput{
ConsumerARN: consumerARN,
ShardId: s.ShardId,
StartingPosition: &kinesis.StartingPosition{Type: aws.String(kinesis.ShardIteratorTypeLatest)},
Expand Down Expand Up @@ -269,14 +266,14 @@ func (a *AWSKinesis) Close() error {
close(a.closeCh)
}
a.wg.Wait()
return nil
return a.authProvider.Close()
}

func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*string, error) {
// Only set timeout on consumer call.
conCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
consumer, err := a.client.DescribeStreamConsumerWithContext(conCtx, &kinesis.DescribeStreamConsumerInput{
consumer, err := a.authProvider.Kinesis().Kinesis.DescribeStreamConsumerWithContext(conCtx, &kinesis.DescribeStreamConsumerInput{
ConsumerName: &a.metadata.ConsumerName,
StreamARN: streamARN,
})
Expand All @@ -288,7 +285,7 @@ func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*st
}

func (a *AWSKinesis) registerConsumer(ctx context.Context, streamARN *string) (*string, error) {
consumer, err := a.client.RegisterStreamConsumerWithContext(ctx, &kinesis.RegisterStreamConsumerInput{
consumer, err := a.authProvider.Kinesis().Kinesis.RegisterStreamConsumerWithContext(ctx, &kinesis.RegisterStreamConsumerInput{
ConsumerName: &a.metadata.ConsumerName,
StreamARN: streamARN,
})
Expand All @@ -307,11 +304,11 @@ func (a *AWSKinesis) registerConsumer(ctx context.Context, streamARN *string) (*
return consumer.Consumer.ConsumerARN, nil
}

func (a *AWSKinesis) deregisterConsumer(streamARN *string, consumerARN *string) error {
func (a *AWSKinesis) deregisterConsumer(ctx context.Context, streamARN *string, consumerARN *string) error {
if a.consumerARN != nil {
// Use a background context because the running context may have been canceled already
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
_, err := a.client.DeregisterStreamConsumerWithContext(ctx, &kinesis.DeregisterStreamConsumerInput{
_, err := a.authProvider.Kinesis().Kinesis.DeregisterStreamConsumerWithContext(ctx, &kinesis.DeregisterStreamConsumerInput{
ConsumerARN: consumerARN,
StreamARN: streamARN,
ConsumerName: &a.metadata.ConsumerName,
Expand Down Expand Up @@ -342,7 +339,7 @@ func (a *AWSKinesis) waitUntilConsumerExists(ctx aws.Context, input *kinesis.Des
tmp := *input
inCpy = &tmp
}
req, _ := a.client.DescribeStreamConsumerRequest(inCpy)
req, _ := a.authProvider.Kinesis().Kinesis.DescribeStreamConsumerRequest(inCpy)
req.SetContext(ctx)
req.ApplyOptions(opts...)

Expand All @@ -354,16 +351,6 @@ func (a *AWSKinesis) waitUntilConsumerExists(ctx aws.Context, input *kinesis.Des
return w.WaitWithContext(ctx)
}

func (a *AWSKinesis) getClient(metadata *kinesisMetadata) (*kinesis.Kinesis, error) {
sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint)
if err != nil {
return nil, err
}
k := kinesis.New(sess)

return k, nil
}

func (a *AWSKinesis) parseMetadata(meta bindings.Metadata) (*kinesisMetadata, error) {
var m kinesisMetadata
err := kitmd.DecodeMetadata(meta.Properties, &m)
Expand Down
Loading

0 comments on commit a00a853

Please sign in to comment.