diff --git a/CHANGELOG.md b/CHANGELOG.md index 661ac6b3..194096c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## Unreleased +IMPROVEMENTS: + * Added Azure API configurable retry options [GH-133](https://github.com/hashicorp/vault-plugin-auth-azure/pull/133) + ## v0.16.1 IMPROVEMENTS: * Updated dependencies: diff --git a/azure.go b/azure.go index 8da75953..798132ae 100644 --- a/azure.go +++ b/azure.go @@ -12,6 +12,7 @@ import ( "net/http" "os" "strings" + "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" @@ -239,6 +240,11 @@ func (p *azureProvider) getClientOptions() *arm.ClientOptions { pluginEnv: p.settings.PluginEnv, sender: p.httpClient, }, + Retry: policy.RetryOptions{ + MaxRetries: p.settings.MaxRetries, + MaxRetryDelay: p.settings.MaxRetryDelay, + RetryDelay: p.settings.RetryDelay, + }, }, } } @@ -273,17 +279,24 @@ func (p *azureProvider) getTokenCredential() (azcore.TokenCredential, error) { } type azureSettings struct { - TenantID string - ClientID string - ClientSecret string - CloudConfig cloud.Configuration - GraphURI string - Resource string - PluginEnv *logical.PluginEnvironment + TenantID string + ClientID string + ClientSecret string + CloudConfig cloud.Configuration + GraphURI string + Resource string + PluginEnv *logical.PluginEnvironment + MaxRetries int32 + MaxRetryDelay time.Duration + RetryDelay time.Duration } func (b *azureAuthBackend) getAzureSettings(ctx context.Context, config *azureConfig) (*azureSettings, error) { - settings := new(azureSettings) + settings := &azureSettings{ + MaxRetries: config.MaxRetries, + MaxRetryDelay: config.MaxRetryDelay, + RetryDelay: config.RetryDelay, + } envTenantID := os.Getenv("AZURE_TENANT_ID") switch { diff --git a/path_config.go b/path_config.go index 3a8a19fd..fd57abb4 100644 --- a/path_config.go +++ b/path_config.go @@ -50,6 +50,24 @@ func pathConfig(b *azureAuthBackend) *framework.Path { Description: "The TTL of the root password in Azure. This can be either a number of seconds or a time formatted duration (ex: 24h, 48ds)", Required: false, }, + "max_retries": { + Type: framework.TypeInt, + Default: defaultMaxRetries, + Description: "The maximum number of attempts a failed operation will be retried before producing an error.", + Required: false, + }, + "max_retry_delay": { + Type: framework.TypeSignedDurationSecond, + Default: defaultMaxRetryDelay, + Description: "The maximum delay allowed before retrying an operation.", + Required: false, + }, + "retry_delay": { + Type: framework.TypeSignedDurationSecond, + Default: defaultRetryDelay, + Description: "The initial amount of delay to use before retrying an operation, increasing exponentially.", + Required: false, + }, }, Operations: map[logical.Operation]framework.OperationHandler{ logical.ReadOperation: &framework.PathOperation{ @@ -101,6 +119,9 @@ type azureConfig struct { NewClientSecretKeyID string `json:"new_client_secret_key_id"` RootPasswordTTL time.Duration `json:"root_password_ttl"` RootPasswordExpirationDate time.Time `json:"root_password_expiration_date"` + MaxRetries int32 `json:"max_retries"` + MaxRetryDelay time.Duration `json:"max_retry_delay"` + RetryDelay time.Duration `json:"retry_delay"` } func (b *azureAuthBackend) config(ctx context.Context, s logical.Storage) (*azureConfig, error) { @@ -167,6 +188,24 @@ func (b *azureAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Req config.RootPasswordTTL = time.Second * time.Duration(rootExpirationRaw.(int)) } + config.MaxRetries = defaultMaxRetries + maxRetriesRaw, ok := data.GetOk("max_retries") + if ok { + config.MaxRetries = int32(maxRetriesRaw.(int)) + } + + config.MaxRetryDelay = defaultMaxRetryDelay + maxRetryDelayRaw, ok := data.GetOk("max_retry_delay") + if ok { + config.MaxRetryDelay = time.Second * time.Duration(maxRetryDelayRaw.(int)) + } + + config.RetryDelay = defaultRetryDelay + retryDelayRaw, ok := data.GetOk("retry_delay") + if ok { + config.RetryDelay = time.Second * time.Duration(retryDelayRaw.(int)) + } + // Create a settings object to validate all required settings // are available if _, err := b.getAzureSettings(ctx, config); err != nil { @@ -203,6 +242,9 @@ func (b *azureAuthBackend) pathConfigRead(ctx context.Context, req *logical.Requ "environment": config.Environment, "client_id": config.ClientID, "root_password_ttl": int(config.RootPasswordTTL.Seconds()), + "retry_delay": config.RetryDelay, + "max_retry_delay": config.MaxRetryDelay, + "max_retries": config.MaxRetries, }, } @@ -246,6 +288,9 @@ const ( // the Azure UI, so we're setting it to 6 months (in hours) // as the default. defaultRootPasswordTTL = 4380 * time.Hour + defaultRetryDelay = 4 * time.Second + defaultMaxRetries = int32(3) + defaultMaxRetryDelay = 60 * time.Second configStoragePath = "config" confHelpSyn = `Configures the Azure authentication backend.` confHelpDesc = ` diff --git a/path_config_test.go b/path_config_test.go index 134291c3..f5dfb998 100644 --- a/path_config_test.go +++ b/path_config_test.go @@ -6,6 +6,7 @@ package azureauth import ( "context" "testing" + "time" "github.com/hashicorp/vault/sdk/logical" ) @@ -87,3 +88,122 @@ func testConfigCreate(t *testing.T, b *azureAuthBackend, s logical.Storage, d ma } return nil } + +func testConfigRead(t *testing.T, b *azureAuthBackend, s logical.Storage) (*logical.Response, error) { + t.Helper() + return b.HandleRequest(context.Background(), &logical.Request{ + Operation: logical.ReadOperation, + Path: "config", + Storage: s, + }) +} + +func TestConfig_RetryDefaults(t *testing.T) { + b, s := getTestBackend(t) + + configData := map[string]interface{}{ + "tenant_id": "tid", + "resource": "resource", + } + + if err := testConfigCreate(t, b, s, configData); err != nil { + t.Fatalf("err: %v", err) + } + + resp, err := testConfigRead(t, b, s) + if err != nil { + t.Fatalf("err: %v", err) + } + + if resp.Data["max_retries"] != defaultMaxRetries { + t.Fatalf("wrong max_retries default: expected %v, got %v", defaultMaxRetries, resp.Data["max_retries"]) + } + + if resp.Data["max_retry_delay"] != defaultMaxRetryDelay { + t.Fatalf("wrong 'max_retry_delay' default: expected %v, got %v", defaultMaxRetryDelay, resp.Data["max_retry_delay"]) + } + + if resp.Data["retry_delay"] != defaultRetryDelay { + t.Fatalf("wrong 'retry_delay' default: expected %v, got %v", defaultRetryDelay, resp.Data["retry_delay"]) + } + + config, err := b.config(context.Background(), s) + if err != nil { + t.Fatalf("err: %v", err) + } + + azureSettings, err := b.getAzureSettings(context.Background(), config) + if err != nil { + t.Fatalf("err: %v", err) + } + + if azureSettings.MaxRetries != defaultMaxRetries { + t.Fatalf("wrong 'max_retries' default azure settings value: expected %v, got %v", defaultMaxRetries, azureSettings.MaxRetries) + } + + if azureSettings.MaxRetryDelay != defaultMaxRetryDelay { + t.Fatalf("wrong 'max_retry_delay' default azure settings value: expected %v, got %v", defaultMaxRetryDelay, azureSettings.MaxRetryDelay) + } + + if azureSettings.RetryDelay != defaultRetryDelay { + t.Fatalf("wrong 'retry_delay' default azure settings value: expected %v, got %v", defaultRetryDelay, azureSettings.RetryDelay) + } +} + +func TestConfig_RetryCustom(t *testing.T) { + b, s := getTestBackend(t) + maxRetries := int32(60) + maxRetryDelay := time.Second * 120 + retryDelay := time.Second * 10 + + configData := map[string]interface{}{ + "tenant_id": "tid", + "resource": "resource", + "max_retries": maxRetries, + "max_retry_delay": maxRetryDelay, + "retry_delay": retryDelay, + } + + if err := testConfigCreate(t, b, s, configData); err != nil { + t.Fatalf("err: %v", err) + } + + resp, err := testConfigRead(t, b, s) + if err != nil { + t.Fatalf("err: %v", err) + } + + if resp.Data["max_retries"] != maxRetries { + t.Fatalf("wrong max_retries value: expected %v, got %v", maxRetries, resp.Data["max_retries"]) + } + + if resp.Data["max_retry_delay"] != maxRetryDelay { + t.Fatalf("wrong 'max_retry_delay' value: expected %v, got %v", maxRetryDelay, resp.Data["max_retry_delay"]) + } + + if resp.Data["retry_delay"] != retryDelay { + t.Fatalf("wrong 'retry_delay' value: expected %v, got %v", retryDelay, resp.Data["retry_delay"]) + } + + config, err := b.config(context.Background(), s) + if err != nil { + t.Fatalf("err: %v", err) + } + + azureSettings, err := b.getAzureSettings(context.Background(), config) + if err != nil { + t.Fatalf("err: %v", err) + } + + if azureSettings.MaxRetries != maxRetries { + t.Fatalf("wrong 'max_retries' azure settings value: expected %v, got %v", maxRetries, azureSettings.MaxRetries) + } + + if azureSettings.MaxRetryDelay != maxRetryDelay { + t.Fatalf("wrong 'max_retry_delay' azure settings value: expected %v, got %v", maxRetryDelay, azureSettings.MaxRetryDelay) + } + + if azureSettings.RetryDelay != retryDelay { + t.Fatalf("wrong 'retry_delay' azure settings value: expected %v, got %v", retryDelay, azureSettings.RetryDelay) + } +}