diff --git a/_integration-tests/tests/ibm_sarama/ibm_sarama.go b/_integration-tests/tests/ibm_sarama/ibm_sarama.go index dfb3b2ea..0320d93c 100644 --- a/_integration-tests/tests/ibm_sarama/ibm_sarama.go +++ b/_integration-tests/tests/ibm_sarama/ibm_sarama.go @@ -53,8 +53,6 @@ func produceMessage(t *testing.T, addrs []string, cfg *sarama.Config) { err := backoff.Retry( context.Background(), backoff.NewConstantStrategy(50*time.Millisecond), - 3, - nil, func() (err error) { defer func() { if r := recover(); r != nil && err == nil { @@ -68,6 +66,7 @@ func produceMessage(t *testing.T, addrs []string, cfg *sarama.Config) { producer, err = sarama.NewSyncProducer(addrs, cfg) return err }, + &backoff.RetryOptions{MaxAttempts: 3}, ) require.NoError(t, err, "failed to create producer") diff --git a/_integration-tests/tests/segmentio_kafka.v0/segmentio_kafka.go b/_integration-tests/tests/segmentio_kafka.v0/segmentio_kafka.go index 833a8803..6fed60d1 100644 --- a/_integration-tests/tests/segmentio_kafka.v0/segmentio_kafka.go +++ b/_integration-tests/tests/segmentio_kafka.v0/segmentio_kafka.go @@ -88,15 +88,17 @@ func (tc *TestCase) produce(t *testing.T) { err := backoff.Retry( ctx, backoff.NewExponentialStrategy(100*time.Millisecond, 2, 5*time.Second), - 10, - func(err error, attempt int, delay time.Duration) bool { - if !errors.Is(err, kafka.UnknownTopicOrPartition) { - return false - } - t.Logf("failed to produce kafka messages, will retry in %s (attempt left: %d)", delay, 10-attempt) - return true - }, func() error { return tc.writer.WriteMessages(ctx, messages...) }, + &backoff.RetryOptions{ + MaxAttempts: 10, + ShouldRetry: func(err error, attempt int, delay time.Duration) bool { + if !errors.Is(err, kafka.UnknownTopicOrPartition) { + return false + } + t.Logf("failed to produce kafka messages, will retry in %s (attempt left: %d)", delay, 10-attempt) + return true + }, + }, ) require.NoError(t, err) require.NoError(t, tc.writer.Close()) diff --git a/_integration-tests/utils/backoff/backoff.go b/_integration-tests/utils/backoff/backoff.go index 5599429b..f983cb48 100644 --- a/_integration-tests/utils/backoff/backoff.go +++ b/_integration-tests/utils/backoff/backoff.go @@ -13,6 +13,7 @@ package backoff import ( "context" "errors" + "math" "time" ) @@ -21,32 +22,67 @@ type Strategy interface { Next() time.Duration } -// Retry makes up to [maxAttempts] at calling the [action] function. It uses the -// [strategy] to determine how much time to wait between attempts. The -// [shouldRetry] functionis called with all non-[nil] errors returned by -// [action] and the retry delay before the next attempt, and should return -// [true] if the error is transient and should be retried, [false] if [Retry] -// should return immediately. If [shouldRetry] is [nil], all errors are retried. -func Retry( - ctx context.Context, - strategy Strategy, - maxAttempts int, - shouldRetry func(error, int, time.Duration) bool, - action func() error, -) error { - return doRetry(ctx, strategy, maxAttempts, shouldRetry, action, time.Sleep) +const ( + defaultMaxAttempts = 10 +) + +// RetryAllErrors is the default function used by [RetryOptions.ShouldRetry]. It +// returns [true] regardless of its arguments. +func RetryAllErrors(error, int, time.Duration) bool { + return true +} + +type RetryOptions struct { + // MaxAttempts is the maximum number of attempts to make before giving up. If + // it is negative, there is no limit to the number of attempts (it will be set + // to [math.MaxInt]); if it is zero, the default value of 10 will be used. It + // is fine (although a little silly) to set [RetryOptions.MaxAttempts] to 1. + MaxAttempts int + // ShouldRetry is called with the error returned by the action, the attempt + // number, and the delay before the next attempt could be made. If it returns + // [true], the next attempt will be made; otherwise, the [Retry] function will + // immediately return. If [nil], the default [RetryAllErrors] function will be + // used. + ShouldRetry func(err error, attempt int, delay time.Duration) bool + // Sleep is the function used to wait in between attempts. It is intended to + // be used in testing. If [nil], the default [time.Sleep] function will be + // used. + Sleep func(time.Duration) } -func doRetry( +// Retry makes up to [RetryOptions.MaxAttempts] at calling the [action] +// function. It uses the [Strategy] to determine how much time to wait between +// attempts. The [RetryOptions.ShouldRetry] function is called with all +// non-[nil] errors returned by [action], the attempt number, and the delay +// before the next attempt. If it returns [true], the [RetryOptions.Sleep] +// function is called with the delay, and the next attempt is made. Otherwise, +// [Retry] returns immediately. +func Retry( ctx context.Context, strategy Strategy, - maxAttempts int, - shouldRetry func(error, int, time.Duration) bool, action func() error, - sleep func(time.Duration), + opts *RetryOptions, ) error { - var errs error + var ( + maxAttempts = defaultMaxAttempts + shouldRetry = RetryAllErrors + sleep = time.Sleep + ) + if opts != nil { + if opts.MaxAttempts > 0 { + maxAttempts = opts.MaxAttempts + } else if opts.MaxAttempts < 0 { + maxAttempts = math.MaxInt + } + if opts.ShouldRetry != nil { + shouldRetry = opts.ShouldRetry + } + if opts.Sleep != nil { + sleep = opts.Sleep + } + } + var errs error for attempt, delay := 0, time.Duration(0); attempt < maxAttempts && ctx.Err() == nil; attempt, delay = attempt+1, strategy.Next() { if delay > 0 { sleep(delay) @@ -65,6 +101,5 @@ func doRetry( break } } - return errors.Join(errs, ctx.Err()) } diff --git a/_integration-tests/utils/backoff/backoff_test.go b/_integration-tests/utils/backoff/backoff_test.go index 7f7652d6..9a3a5e88 100644 --- a/_integration-tests/utils/backoff/backoff_test.go +++ b/_integration-tests/utils/backoff/backoff_test.go @@ -8,6 +8,7 @@ package backoff import ( "context" "fmt" + "math/rand" "strings" "testing" "time" @@ -47,7 +48,7 @@ func TestRetry(t *testing.T) { delays = append(delays, d) } - err := doRetry(ctx, strategy, maxAttempts, nil, action, timeSleep) + err := Retry(ctx, strategy, action, &RetryOptions{MaxAttempts: maxAttempts, Sleep: timeSleep}) require.Error(t, err) assert.Equal(t, delaySequence, delays) for _, expectedErr := range expectedErrs { @@ -73,7 +74,7 @@ func TestRetry(t *testing.T) { delays = append(delays, d) } - err := doRetry(ctx, strategy, maxAttempts, shouldRetry, action, timeSleep) + err := Retry(ctx, strategy, action, &RetryOptions{MaxAttempts: maxAttempts, ShouldRetry: shouldRetry, Sleep: timeSleep}) require.Error(t, err) // We hit the non-retryable error at the 3rd attempt. assert.Equal(t, delaySequence[:2], delays) @@ -108,28 +109,51 @@ func TestRetry(t *testing.T) { } } - err := doRetry(ctx, strategy, maxAttempts, nil, action, timeSleep) + err := Retry(ctx, strategy, action, &RetryOptions{MaxAttempts: maxAttempts, Sleep: timeSleep}) require.Error(t, err) // We reach the 1 second total waited during the 4th back-off. assert.Equal(t, delaySequence[:4], delays) for _, expectedErr := range expectedErrs { - assert.ErrorIs(t, err, expectedErr) + require.ErrorIs(t, err, expectedErr) + } + require.ErrorIs(t, err, context.Canceled) + }) + + t.Run("unlimited retries", func(t *testing.T) { + ctx := context.Background() + strategy := NewConstantStrategy(100 * time.Millisecond) + var attempts int + action := func() error { + attempts++ + // At least 20 errors, then flip a coin... but no more than 100 attempts. + if attempts < 20 || (attempts < 100 && rand.Int()%2 == 0) { + return fmt.Errorf("Error number %d", attempts) + } + return nil + } + var delayCount int + timeSleep := func(time.Duration) { + delayCount++ } - assert.ErrorIs(t, err, context.Canceled) + + err := Retry(ctx, strategy, action, &RetryOptions{MaxAttempts: -1, Sleep: timeSleep}) + require.NoError(t, err) + // We should have waited as many times as we attempted, except for the initial attempt. + assert.Equal(t, delayCount, attempts-1) }) t.Run("immediate success", func(t *testing.T) { ctx := context.Background() strategy := NewExponentialStrategy(100*time.Millisecond, 2, 5*time.Second) maxAttempts := 10 - shouldRetry := func(err error, _ int, _ time.Duration) bool { return false } + shouldRetry := func(error, int, time.Duration) bool { return false } action := func() error { return nil } delays := make([]time.Duration, 0, maxAttempts) timeSleep := func(d time.Duration) { delays = append(delays, d) } - err := doRetry(ctx, strategy, maxAttempts, shouldRetry, action, timeSleep) + err := Retry(ctx, strategy, action, &RetryOptions{MaxAttempts: maxAttempts, ShouldRetry: shouldRetry, Sleep: timeSleep}) require.NoError(t, err) assert.Empty(t, delays) })