diff --git a/testutil/testutil.go b/testutil/testutil.go index 319702141..27ab8455a 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -907,3 +907,28 @@ func GetTestCertPool(t *testing.T, cert []byte) *x509.CertPool { } return pool } + +type TestRetryHandler struct { + Requests int + Retries int + OKAtCount int + RespData []byte + RetryStatus int +} + +func (r *TestRetryHandler) Handler() http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + if r.Requests > 0 { + r.Retries++ + } + + r.Requests++ + if r.OKAtCount > 0 && (r.Requests == r.OKAtCount) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write(r.RespData) + return + } else { + w.WriteHeader(r.RetryStatus) + } + } +} diff --git a/util/util.go b/util/util.go index e6f58e3c4..19954494b 100644 --- a/util/util.go +++ b/util/util.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/cenkalti/backoff/v4" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/vault/api" @@ -405,3 +406,63 @@ func Remount(d *schema.ResourceData, client *api.Client, mountField string, isAu return ret, nil } + +type RetryRequestOpts struct { + MaxTries uint64 + Delay time.Duration + StatusCodes []int +} + +func (r *RetryRequestOpts) IsRetryableStatus(statusCode int) bool { + for _, s := range r.StatusCodes { + if s == statusCode { + return true + } + } + + return false +} + +func DefaultRequestOpts() *RetryRequestOpts { + return &RetryRequestOpts{ + MaxTries: 60, + Delay: time.Millisecond * 500, + StatusCodes: []int{http.StatusBadRequest}, + } +} + +// RetryWrite attempts to retry a Logical.Write() to Vault for the +// RetryRequestOpts. Primary useful for handling some of Vault's eventually +// consistent APIs. +func RetryWrite(client *api.Client, path string, data map[string]interface{}, req *RetryRequestOpts) (*api.Secret, error) { + if req == nil { + req = DefaultRequestOpts() + } + + if path == "" { + return nil, fmt.Errorf("path is empty") + } + + bo := backoff.NewConstantBackOff(req.Delay) + + var resp *api.Secret + return resp, backoff.RetryNotify( + func() error { + r, err := client.Logical().Write(path, data) + if err != nil { + e := fmt.Errorf("error writing to path %q, err=%w", path, err) + if respErr, ok := err.(*api.ResponseError); ok { + if req.IsRetryableStatus(respErr.StatusCode) { + return e + } + } + + return backoff.Permanent(e) + } + resp = r + return nil + }, backoff.WithMaxRetries(bo, req.MaxTries), + func(err error, duration time.Duration) { + log.Printf("[WARN] Writing to path %q failed, retrying in %s", path, duration) + }) +} diff --git a/util/util_test.go b/util/util_test.go index 7e05d5194..6ee86f299 100644 --- a/util/util_test.go +++ b/util/util_test.go @@ -4,11 +4,17 @@ package util import ( + "encoding/json" "fmt" + "net/http" "reflect" "testing" + "time" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" + "github.com/hashicorp/vault/api" + + "github.com/hashicorp/terraform-provider-vault/testutil" ) type testingStruct struct { @@ -694,3 +700,174 @@ func TestCalculateConflictsWith(t *testing.T) { }) } } + +func TestRetryWrite(t *testing.T) { + tests := []struct { + name string + path string + reqData map[string]interface{} + req *RetryRequestOpts + retryHandler *testutil.TestRetryHandler + want *api.Secret + wantErr bool + }{ + { + name: "ok-without-retries", + path: "foo/baz", + reqData: map[string]interface{}{ + "qux": "baz", + }, + req: &RetryRequestOpts{ + MaxTries: 3, + Delay: time.Millisecond * 100, + StatusCodes: []int{http.StatusBadRequest}, + }, + retryHandler: &testutil.TestRetryHandler{ + OKAtCount: 1, + }, + want: &api.Secret{ + Data: map[string]interface{}{ + "qux": "baz", + }, + }, + wantErr: false, + }, + { + name: "ok-with-retries", + path: "foo/baz", + reqData: map[string]interface{}{ + "baz": "biff", + }, + req: &RetryRequestOpts{ + MaxTries: 3, + Delay: time.Millisecond * 100, + StatusCodes: []int{http.StatusBadRequest}, + }, + retryHandler: &testutil.TestRetryHandler{ + OKAtCount: 2, + RetryStatus: http.StatusBadRequest, + }, + want: &api.Secret{ + Data: map[string]interface{}{ + "baz": "biff", + }, + }, + wantErr: false, + }, + { + name: "non-retryable-no-status", + path: "foo/baz", + reqData: map[string]interface{}{ + "baz": "biff", + }, + req: &RetryRequestOpts{ + MaxTries: 3, + Delay: time.Millisecond * 100, + StatusCodes: []int{}, + }, + retryHandler: &testutil.TestRetryHandler{ + RetryStatus: http.StatusConflict, + }, + want: nil, + wantErr: true, + }, + { + name: "max-retries-exceeded-single", + path: "foo/baz", + reqData: map[string]interface{}{ + "baz": "biff", + }, + req: &RetryRequestOpts{ + MaxTries: 3, + Delay: time.Millisecond * 100, + StatusCodes: []int{http.StatusBadRequest}, + }, + retryHandler: &testutil.TestRetryHandler{ + RetryStatus: http.StatusBadRequest, + }, + want: nil, + wantErr: true, + }, + { + name: "max-retries-exceeded-choices", + path: "foo/baz", + reqData: map[string]interface{}{ + "baz": "biff", + }, + req: &RetryRequestOpts{ + MaxTries: 3, + Delay: time.Millisecond * 100, + StatusCodes: []int{http.StatusBadRequest, http.StatusConflict}, + }, + retryHandler: &testutil.TestRetryHandler{ + RetryStatus: http.StatusConflict, + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config, ln := testutil.TestHTTPServer(t, tt.retryHandler.Handler()) + defer ln.Close() + + config.Address = fmt.Sprintf("http://%s", ln.Addr()) + client, err := api.NewClient(config) + if err != nil { + t.Fatal(err) + } + + if !tt.wantErr && tt.retryHandler.RespData == nil { + b, err := json.Marshal(tt.reqData) + if err != nil { + t.Fatal(err) + } + tt.retryHandler.RespData = b + } + + got, err := RetryWrite(client, tt.path, tt.reqData, tt.req) + if (err != nil) != tt.wantErr { + t.Errorf("RetryWrite() error = %v, wantErr %v", err, + tt.wantErr) + return + } + + if tt.wantErr { + if len(tt.req.StatusCodes) == 0 { + if tt.retryHandler.Retries != 0 { + t.Fatalf("expected 0 retries, actual %d", + tt.retryHandler.Retries) + } + if tt.retryHandler.Requests != 1 { + t.Fatalf("expected 1 requests, actual %d", + tt.retryHandler.Requests) + } + } else { + if int(tt.req.MaxTries) != tt.retryHandler.Retries { + t.Fatalf("expected %d retries, actual %d", + tt.req.MaxTries, tt.retryHandler.Requests) + } + } + } else { + if tt.retryHandler.OKAtCount != tt.retryHandler.Requests { + t.Fatalf("expected %d retries, actual %d", + tt.retryHandler.OKAtCount, tt.retryHandler.Requests) + } + + var expectedRetries int + if tt.retryHandler.OKAtCount > 1 { + expectedRetries = tt.retryHandler.Requests - 1 + } + + if expectedRetries != tt.retryHandler.Retries { + t.Fatalf("expected %d retries, actual %d", + expectedRetries, tt.retryHandler.Requests) + } + } + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("RetryWrite() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/vault/resource_generic_secret.go b/vault/resource_generic_secret.go index 9e880dd0c..365541fb6 100644 --- a/vault/resource_generic_secret.go +++ b/vault/resource_generic_secret.go @@ -12,6 +12,7 @@ import ( "github.com/hashicorp/terraform-provider-vault/internal/consts" "github.com/hashicorp/terraform-provider-vault/internal/provider" + "github.com/hashicorp/terraform-provider-vault/util" ) const latestSecretVersion = -1 @@ -148,7 +149,7 @@ func genericSecretResourceWrite(d *schema.ResourceData, meta interface{}) error } - if err := writeSecretDataWithRetry(client, path, data); err != nil { + if _, err := util.RetryWrite(client, path, data, util.DefaultRequestOpts()); err != nil { return err } diff --git a/vault/resource_identity_entity_test.go b/vault/resource_identity_entity_test.go index 8a649e61c..19e0ca477 100644 --- a/vault/resource_identity_entity_test.go +++ b/vault/resource_identity_entity_test.go @@ -224,34 +224,34 @@ func TestReadEntity(t *testing.T) { maxRetries int expectedRetries int wantError error - retryHandler *testRetryHandler + retryHandler *testutil.TestRetryHandler }{ { name: "retry-none", - retryHandler: &testRetryHandler{ - okAtCount: 1, - // retryStatus: http.StatusNotFound, - respData: []byte(`{"data": {"foo": "baz"}}`), + retryHandler: &testutil.TestRetryHandler{ + OKAtCount: 1, + // RetryStatus: http.StatusNotFound, + RespData: []byte(`{"data": {"foo": "baz"}}`), }, maxRetries: 4, expectedRetries: 0, }, { name: "retry-ok-404", - retryHandler: &testRetryHandler{ - okAtCount: 3, - retryStatus: http.StatusNotFound, - respData: []byte(`{"data": {"foo": "baz"}}`), + retryHandler: &testutil.TestRetryHandler{ + OKAtCount: 3, + RetryStatus: http.StatusNotFound, + RespData: []byte(`{"data": {"foo": "baz"}}`), }, maxRetries: 4, expectedRetries: 2, }, { name: "retry-ok-412", - retryHandler: &testRetryHandler{ - okAtCount: 3, - retryStatus: http.StatusPreconditionFailed, - respData: []byte(`{"data": {"foo": "baz"}}`), + retryHandler: &testutil.TestRetryHandler{ + OKAtCount: 3, + RetryStatus: http.StatusPreconditionFailed, + RespData: []byte(`{"data": {"foo": "baz"}}`), }, maxRetries: 4, expectedRetries: 2, @@ -259,9 +259,9 @@ func TestReadEntity(t *testing.T) { { name: "retry-exhausted-default-max-404", path: entity.JoinEntityID("retry-exhausted-default-max-404"), - retryHandler: &testRetryHandler{ - okAtCount: 0, - retryStatus: http.StatusNotFound, + retryHandler: &testutil.TestRetryHandler{ + OKAtCount: 0, + RetryStatus: http.StatusNotFound, }, maxRetries: DefaultMaxHTTPRetriesCCC, expectedRetries: DefaultMaxHTTPRetriesCCC, @@ -271,9 +271,9 @@ func TestReadEntity(t *testing.T) { { name: "retry-exhausted-default-max-412", path: entity.JoinEntityID("retry-exhausted-default-max-412"), - retryHandler: &testRetryHandler{ - okAtCount: 0, - retryStatus: http.StatusPreconditionFailed, + retryHandler: &testutil.TestRetryHandler{ + OKAtCount: 0, + RetryStatus: http.StatusPreconditionFailed, }, maxRetries: DefaultMaxHTTPRetriesCCC, expectedRetries: DefaultMaxHTTPRetriesCCC, @@ -283,9 +283,9 @@ func TestReadEntity(t *testing.T) { { name: "retry-exhausted-custom-max-404", path: entity.JoinEntityID("retry-exhausted-custom-max-404"), - retryHandler: &testRetryHandler{ - okAtCount: 0, - retryStatus: http.StatusNotFound, + retryHandler: &testutil.TestRetryHandler{ + OKAtCount: 0, + RetryStatus: http.StatusNotFound, }, maxRetries: 5, expectedRetries: 5, @@ -295,9 +295,9 @@ func TestReadEntity(t *testing.T) { { name: "retry-exhausted-custom-max-412", path: entity.JoinEntityID("retry-exhausted-custom-max-412"), - retryHandler: &testRetryHandler{ - okAtCount: 0, - retryStatus: http.StatusPreconditionFailed, + retryHandler: &testutil.TestRetryHandler{ + OKAtCount: 0, + RetryStatus: http.StatusPreconditionFailed, }, maxRetries: 5, expectedRetries: 5, @@ -315,7 +315,7 @@ func TestReadEntity(t *testing.T) { r := tt.retryHandler - config, ln := testutil.TestHTTPServer(t, r.handler()) + config, ln := testutil.TestHTTPServer(t, r.Handler()) defer ln.Close() config.Address = fmt.Sprintf("http://%s", ln.Addr()) @@ -342,7 +342,7 @@ func TestReadEntity(t *testing.T) { t.Errorf("expected err %q, actual %q", tt.wantError, err) } - if tt.retryHandler.retryStatus == http.StatusNotFound { + if tt.retryHandler.RetryStatus == http.StatusNotFound { if !group.IsIdentityNotFoundError(err) { t.Errorf("expected an errEntityNotFound err %q, actual %q", entity.ErrEntityNotFound, err) } @@ -353,8 +353,8 @@ func TestReadEntity(t *testing.T) { } var data map[string]interface{} - if err := json.Unmarshal(tt.retryHandler.respData, &data); err != nil { - t.Fatalf("invalid test data %#v, err=%s", tt.retryHandler.respData, err) + if err := json.Unmarshal(tt.retryHandler.RespData, &data); err != nil { + t.Fatalf("invalid test data %#v, err=%s", tt.retryHandler.RespData, err) } expectedResp := &api.Secret{ @@ -366,9 +366,8 @@ func TestReadEntity(t *testing.T) { } } - retries := r.requests - 1 - if tt.expectedRetries != retries { - t.Fatalf("expected %d retries, actual %d", tt.expectedRetries, retries) + if tt.expectedRetries != r.Retries { + t.Fatalf("expected %d retries, actual %d", tt.expectedRetries, r.Retries) } }) } @@ -405,23 +404,3 @@ func TestIsEntityNotFoundError(t *testing.T) { }) } } - -type testRetryHandler struct { - requests int - okAtCount int - respData []byte - retryStatus int -} - -func (t *testRetryHandler) handler() http.HandlerFunc { - return func(w http.ResponseWriter, req *http.Request) { - t.requests++ - if t.okAtCount > 0 && (t.requests >= t.okAtCount) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write(t.respData) - return - } else { - w.WriteHeader(t.retryStatus) - } - } -} diff --git a/vault/resource_kv_secret_backend_v2.go b/vault/resource_kv_secret_backend_v2.go index 8e50335b6..aa6eab24a 100644 --- a/vault/resource_kv_secret_backend_v2.go +++ b/vault/resource_kv_secret_backend_v2.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/terraform-provider-vault/internal/consts" "github.com/hashicorp/terraform-provider-vault/internal/provider" + "github.com/hashicorp/terraform-provider-vault/util" ) func kvSecretBackendV2Resource() *schema.Resource { @@ -72,7 +73,7 @@ func kvSecretBackendV2CreateUpdate(ctx context.Context, d *schema.ResourceData, } path := mount + "/config" - if _, err := client.Logical().Write(path, data); err != nil { + if _, err := util.RetryWrite(client, path, data, util.DefaultRequestOpts()); err != nil { return diag.Errorf("error writing config data to %s, err=%s", path, err) } diff --git a/vault/resource_kv_secret_v2.go b/vault/resource_kv_secret_v2.go index 86c68db62..f35368644 100644 --- a/vault/resource_kv_secret_v2.go +++ b/vault/resource_kv_secret_v2.go @@ -8,17 +8,16 @@ import ( "encoding/json" "fmt" "log" - "net/http" "regexp" "time" - "github.com/cenkalti/backoff/v4" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/vault/api" "github.com/hashicorp/terraform-provider-vault/internal/consts" "github.com/hashicorp/terraform-provider-vault/internal/provider" + "github.com/hashicorp/terraform-provider-vault/util" ) var ( @@ -203,7 +202,7 @@ func kvSecretV2Write(ctx context.Context, d *schema.ResourceData, meta interface data[k] = d.Get(k) } - if err := writeSecretDataWithRetry(client, path, data); err != nil { + if _, err := util.RetryWrite(client, path, data, util.DefaultRequestOpts()); err != nil { return diag.FromErr(err) } @@ -390,21 +389,3 @@ func getKVV2SecretMountFromPath(path string) (string, error) { } return res[1], nil } - -func writeSecretDataWithRetry(client *api.Client, path string, data map[string]interface{}) error { - return backoff.RetryNotify( - func() error { - if _, err := client.Logical().Write(path, data); err != nil { - e := fmt.Errorf("error writing secret data: %w", err) - if respErr, ok := err.(*api.ResponseError); ok && (respErr.StatusCode == http.StatusBadRequest) { - return e - } - - return backoff.Permanent(e) - } - return nil - }, backoff.WithMaxRetries(backoff.NewConstantBackOff(time.Millisecond*500), 10), - func(err error, duration time.Duration) { - log.Printf("[WARN] Writing secret data to %q failed, retrying in %s", path, duration) - }) -}