diff --git a/util/client.go b/util/client.go index 3eb18d891..0bf1bbcc3 100644 --- a/util/client.go +++ b/util/client.go @@ -8,6 +8,7 @@ import ( "os" "strings" "sync" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials/stscreds" @@ -21,78 +22,89 @@ import ( type Client struct { *api.Client - Data *schema.ResourceData - mutex sync.Mutex + Data *schema.ResourceData + expireTime *time.Time + mutex sync.Mutex } func (c *Client) lazyInit() error { c.mutex.Lock() defer c.mutex.Unlock() + client := c.Client + d := c.Data + // client has already be initialized. if c.Client != nil { - return nil - } - - d := c.Data + if c.expireTime != nil && c.expireTime.After(time.Now()) { + log.Printf("[DEBUG] Vault client as already been initialized") + return nil + } - clientConfig := api.DefaultConfig() - addr := d.Get("address").(string) - if addr != "" { - clientConfig.Address = addr + client.ClearToken() } - clientAuthI := d.Get("client_auth").([]interface{}) - if len(clientAuthI) > 1 { - return fmt.Errorf("client_auth block may appear only once") - } + if client == nil { + clientConfig := api.DefaultConfig() + addr := d.Get("address").(string) + if addr != "" { + clientConfig.Address = addr + } - clientAuthCert := "" - clientAuthKey := "" - if len(clientAuthI) == 1 { - clientAuth := clientAuthI[0].(map[string]interface{}) - clientAuthCert = clientAuth["cert_file"].(string) - clientAuthKey = clientAuth["key_file"].(string) - } + clientAuthI := d.Get("client_auth").([]interface{}) + if len(clientAuthI) > 1 { + return fmt.Errorf("client_auth block may appear only once") + } - err := clientConfig.ConfigureTLS(&api.TLSConfig{ - CACert: d.Get("ca_cert_file").(string), - CAPath: d.Get("ca_cert_dir").(string), - Insecure: d.Get("skip_tls_verify").(bool), + clientAuthCert := "" + clientAuthKey := "" + if len(clientAuthI) == 1 { + clientAuth := clientAuthI[0].(map[string]interface{}) + clientAuthCert = clientAuth["cert_file"].(string) + clientAuthKey = clientAuth["key_file"].(string) + } - ClientCert: clientAuthCert, - ClientKey: clientAuthKey, - }) - if err != nil { - return fmt.Errorf("failed to configure TLS for Vault API: %s", err) - } + err := clientConfig.ConfigureTLS(&api.TLSConfig{ + CACert: d.Get("ca_cert_file").(string), + CAPath: d.Get("ca_cert_dir").(string), + Insecure: d.Get("skip_tls_verify").(bool), - clientConfig.HttpClient.Transport = logging.NewTransport("Vault", clientConfig.HttpClient.Transport) + ClientCert: clientAuthCert, + ClientKey: clientAuthKey, + }) + if err != nil { + return fmt.Errorf("failed to configure TLS for Vault API: %s", err) + } - client, err := api.NewClient(clientConfig) - if err != nil { - return fmt.Errorf("failed to configure Vault API: %s", err) - } + clientConfig.HttpClient.Transport = logging.NewTransport("Vault", clientConfig.HttpClient.Transport) - client.SetCloneHeaders(true) + client, err = api.NewClient(clientConfig) + if err != nil { + return fmt.Errorf("failed to configure Vault API: %s", err) + } - // Set headers if provided - headers := d.Get("headers").([]interface{}) - parsedHeaders := client.Headers().Clone() + client.SetCloneHeaders(true) - if parsedHeaders == nil { - parsedHeaders = make(http.Header) - } + // Set headers if provided + headers := d.Get("headers").([]interface{}) + parsedHeaders := client.Headers().Clone() - for _, h := range headers { - header := h.(map[string]interface{}) - if name, ok := header["name"]; ok { - parsedHeaders.Add(name.(string), header["value"].(string)) + if parsedHeaders == nil { + parsedHeaders = make(http.Header) } - } - client.SetHeaders(parsedHeaders) - client.SetMaxRetries(d.Get("max_retries").(int)) + for _, h := range headers { + header := h.(map[string]interface{}) + if name, ok := header["name"]; ok { + parsedHeaders.Add(name.(string), header["value"].(string)) + } + } + client.SetHeaders(parsedHeaders) + + client.SetMaxRetries(d.Get("max_retries").(int)) + + c.Client = client + } // Try an get the token from the config or token helper token, err := ProviderToken(d) @@ -167,6 +179,8 @@ func (c *Client) lazyInit() error { } } + expireTime := time.Now() + renewable := false childTokenLease, err := client.Auth().Token().Create(&api.TokenCreateRequest{ DisplayName: tokenName, @@ -181,6 +195,10 @@ func (c *Client) lazyInit() error { childToken := childTokenLease.Auth.ClientToken policies := childTokenLease.Auth.Policies + // store the expiration time of the token. + expireTime = expireTime.Add(time.Duration(childTokenLease.Auth.LeaseDuration) * time.Second) + c.expireTime = &expireTime + log.Printf("[INFO] Using Vault token with the following policies: %s", strings.Join(policies, ", ")) // Set the token to the generated child token @@ -192,9 +210,6 @@ func (c *Client) lazyInit() error { client.SetNamespace(namespace) } - c.Client = client - c.Data = nil - return nil } @@ -304,7 +319,6 @@ func SignAWSLogin(parameters map[string]interface{}) error { Config: *config, AssumeRoleTokenProvider: stscreds.StdinTokenProvider, }) - if err != nil { log.Fatalf("session error: %v", err) }