diff --git a/aws/internal/tfresource/retry.go b/aws/internal/tfresource/retry.go index 517a6b21933..544b2d68879 100644 --- a/aws/internal/tfresource/retry.go +++ b/aws/internal/tfresource/retry.go @@ -7,6 +7,39 @@ import ( "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" ) +// RetryUntilFound retries the specified function until the underlying resource is found. +// The function returns a resource.NotFoundError to indicate that the underlying resource does not exist. +// If the retries time out, the function is called one last time. +func RetryUntilFound(timeout time.Duration, f func() (interface{}, error)) (interface{}, error) { + var output interface{} + + err := resource.Retry(timeout, func() *resource.RetryError { + var err error + + output, err = f() + + if NotFound(err) { + return resource.RetryableError(err) + } + + if err != nil { + return resource.NonRetryableError(err) + } + + return nil + }) + + if TimedOut(err) { + output, err = f() + } + + if err != nil { + return nil, err + } + + return output, err +} + // RetryWhenAwsErrCodeEquals retries the specified function when it returns one of the specified AWS error code. func RetryWhenAwsErrCodeEquals(timeout time.Duration, f func() (interface{}, error), codes ...string) (interface{}, error) { var output interface{} diff --git a/aws/internal/tfresource/retry_test.go b/aws/internal/tfresource/retry_test.go index 8a396b9e3de..ecb8b14907f 100644 --- a/aws/internal/tfresource/retry_test.go +++ b/aws/internal/tfresource/retry_test.go @@ -7,9 +7,65 @@ import ( "time" "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" "github.com/terraform-providers/terraform-provider-aws/aws/internal/tfresource" ) +func TestRetryUntilFound(t *testing.T) { + var retryCount int32 + + testCases := []struct { + Name string + F func() (interface{}, error) + ExpectError bool + }{ + { + Name: "no error", + F: func() (interface{}, error) { + return nil, nil + }, + }, + { + Name: "non-retryable other error", + F: func() (interface{}, error) { + return nil, errors.New("TestCode") + }, + ExpectError: true, + }, + { + Name: "retryable not-found error timeout", + F: func() (interface{}, error) { + return nil, &resource.NotFoundError{} + }, + ExpectError: true, + }, + { + Name: "retryable AWS error success", + F: func() (interface{}, error) { + if atomic.CompareAndSwapInt32(&retryCount, 0, 1) { + return nil, &resource.NotFoundError{} + } + + return nil, nil + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.Name, func(t *testing.T) { + retryCount = 0 + + _, err := tfresource.RetryUntilFound(5*time.Second, testCase.F) + + if testCase.ExpectError && err == nil { + t.Fatal("expected error") + } else if !testCase.ExpectError && err != nil { + t.Fatalf("unexpected error: %s", err) + } + }) + } +} + func TestRetryWhenAwsErrCodeEquals(t *testing.T) { var retryCount int32