Skip to content

Commit

Permalink
snssqs: fix consumer starvation (#3478)
Browse files Browse the repository at this point in the history
Signed-off-by: Gustavo Chain <[email protected]>
Signed-off-by: Alessandro (Ale) Segala <[email protected]>
Co-authored-by: Alessandro (Ale) Segala <[email protected]>
Co-authored-by: Bernd Verst <[email protected]>
Co-authored-by: Yaron Schneider <[email protected]>
  • Loading branch information
4 people authored Nov 26, 2024
1 parent 2aea319 commit 1137759
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 6 deletions.
6 changes: 6 additions & 0 deletions pubsub/aws/snssqs/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ type snsSqsMetadata struct {
AccountID string `mapstructure:"accountID"`
// processing concurrency mode
ConcurrencyMode pubsub.ConcurrencyMode `mapstructure:"concurrencyMode"`
// limits the number of concurrent goroutines
ConcurrencyLimit int `mapstructure:"concurrencyLimit"`
}

func maskLeft(s string) string {
Expand Down Expand Up @@ -130,6 +132,10 @@ func (s *snsSqs) getSnsSqsMetadata(meta pubsub.Metadata) (*snsSqsMetadata, error
return nil, err
}

if md.ConcurrencyLimit < 0 {
return nil, errors.New("concurrencyLimit must be greater than or equal to 0")
}

s.logger.Debug(md.hideDebugPrintedCredentials())

return md, nil
Expand Down
9 changes: 9 additions & 0 deletions pubsub/aws/snssqs/metadata.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ metadata:
default: '"parallel"'
example: '"single", "parallel"'
type: string
- name: concurrencyLimit
required: false
description: |
Defines the maximum number of concurrent workers handling messages.
This value is ignored when "concurrencyMode" is set to "single".
To avoid limiting the number of concurrent workers, set this to "0".
type: number
default: '0'
example: '100'
- name: accountId
required: false
description: |
Expand Down
23 changes: 17 additions & 6 deletions pubsub/aws/snssqs/snssqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,13 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
WaitTimeSeconds: aws.Int64(s.metadata.MessageWaitTimeSeconds),
}

// sem is a semaphore used to control the concurrencyLimit.
// It is set only when we are in parallel mode and limit is > 0.
var sem chan (struct{}) = nil
if (s.metadata.ConcurrencyMode == pubsub.Parallel) && s.metadata.ConcurrencyLimit > 0 {
sem = make(chan struct{}, s.metadata.ConcurrencyLimit)
}

for {
// If the context is canceled, stop requesting messages
if ctx.Err() != nil {
Expand Down Expand Up @@ -629,33 +636,37 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
}
s.logger.Debugf("%v message(s) received on queue %s", len(messageResponse.Messages), queueInfo.arn)

var wg sync.WaitGroup
for _, message := range messageResponse.Messages {
if err := s.validateMessage(ctx, message, queueInfo, deadLettersQueueInfo); err != nil {
s.logger.Errorf("message is not valid for further processing by the handler. error is: %v", err)
continue
}

f := func(message *sqs.Message) {
defer wg.Done()
if err := s.callHandler(ctx, message, queueInfo); err != nil {
s.logger.Errorf("error while handling received message. error is: %v", err)
}
}

wg.Add(1)
switch s.metadata.ConcurrencyMode {
case pubsub.Single:
f(message)
case pubsub.Parallel:
wg.Add(1)
// This is the back pressure mechanism.
// It will block until another goroutine frees a slot.
if sem != nil {
sem <- struct{}{}
}

go func(message *sqs.Message) {
defer wg.Done()
if sem != nil {
defer func() { <-sem }()
}

f(message)
}(message)
}
}
wg.Wait()
}
}

Expand Down
17 changes: 17 additions & 0 deletions pubsub/aws/snssqs/snssqs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ func Test_getSnsSqsMetadata_AllConfiguration(t *testing.T) {
"consumerID": "consumer",
"Endpoint": "endpoint",
"concurrencyMode": string(pubsub.Single),
"concurrencyLimit": "42",
"accessKey": "a",
"secretKey": "s",
"sessionToken": "t",
Expand All @@ -68,6 +69,7 @@ func Test_getSnsSqsMetadata_AllConfiguration(t *testing.T) {
r.Equal("consumer", md.SqsQueueName)
r.Equal("endpoint", md.Endpoint)
r.Equal(pubsub.Single, md.ConcurrencyMode)
r.Equal(42, md.ConcurrencyLimit)
r.Equal("a", md.AccessKey)
r.Equal("s", md.SecretKey)
r.Equal("t", md.SessionToken)
Expand Down Expand Up @@ -105,6 +107,7 @@ func Test_getSnsSqsMetadata_defaults(t *testing.T) {
r.Equal("", md.SessionToken)
r.Equal("r", md.Region)
r.Equal(pubsub.Parallel, md.ConcurrencyMode)
r.Equal(0, md.ConcurrencyLimit)
r.Equal(int64(10), md.MessageVisibilityTimeout)
r.Equal(int64(10), md.MessageRetryLimit)
r.Equal(int64(2), md.MessageWaitTimeSeconds)
Expand Down Expand Up @@ -273,6 +276,20 @@ func Test_getSnsSqsMetadata_invalidMetadataSetup(t *testing.T) {
}}},
name: "invalid message concurrencyMode",
},
// invalid concurrencyLimit
{
metadata: pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{
"consumerID": "consumer",
"Endpoint": "endpoint",
"AccessKey": "acctId",
"SecretKey": "secret",
"awsToken": "token",
"Region": "region",
"messageRetryLimit": "10",
"concurrencyLimit": "-1",
}}},
name: "invalid message concurrencyLimit",
},
}

l := logger.NewLogger("SnsSqs unit test")
Expand Down

0 comments on commit 1137759

Please sign in to comment.