Skip to content

Commit

Permalink
Merge pull request hashicorp#1 from greut/fix/token-expiration
Browse files Browse the repository at this point in the history
fix: handle Token expiration
  • Loading branch information
cyrilgdn authored Aug 12, 2021
2 parents 2c7968a + 971cf59 commit 5869370
Showing 1 changed file with 68 additions and 54 deletions.
122 changes: 68 additions & 54 deletions util/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"strings"
"sync"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
Expand All @@ -21,78 +22,89 @@ import (

type Client struct {
*api.Client
Data *schema.ResourceData
mutex sync.Mutex
Data *schema.ResourceData
expireTime *time.Time
mutex sync.Mutex
}

func (c *Client) lazyInit() error {
c.mutex.Lock()
defer c.mutex.Unlock()

client := c.Client
d := c.Data

// client has already be initialized.
if c.Client != nil {
return nil
}

d := c.Data
if c.expireTime != nil && c.expireTime.After(time.Now()) {
log.Printf("[DEBUG] Vault client as already been initialized")
return nil
}

clientConfig := api.DefaultConfig()
addr := d.Get("address").(string)
if addr != "" {
clientConfig.Address = addr
client.ClearToken()
}

clientAuthI := d.Get("client_auth").([]interface{})
if len(clientAuthI) > 1 {
return fmt.Errorf("client_auth block may appear only once")
}
if client == nil {
clientConfig := api.DefaultConfig()
addr := d.Get("address").(string)
if addr != "" {
clientConfig.Address = addr
}

clientAuthCert := ""
clientAuthKey := ""
if len(clientAuthI) == 1 {
clientAuth := clientAuthI[0].(map[string]interface{})
clientAuthCert = clientAuth["cert_file"].(string)
clientAuthKey = clientAuth["key_file"].(string)
}
clientAuthI := d.Get("client_auth").([]interface{})
if len(clientAuthI) > 1 {
return fmt.Errorf("client_auth block may appear only once")
}

err := clientConfig.ConfigureTLS(&api.TLSConfig{
CACert: d.Get("ca_cert_file").(string),
CAPath: d.Get("ca_cert_dir").(string),
Insecure: d.Get("skip_tls_verify").(bool),
clientAuthCert := ""
clientAuthKey := ""
if len(clientAuthI) == 1 {
clientAuth := clientAuthI[0].(map[string]interface{})
clientAuthCert = clientAuth["cert_file"].(string)
clientAuthKey = clientAuth["key_file"].(string)
}

ClientCert: clientAuthCert,
ClientKey: clientAuthKey,
})
if err != nil {
return fmt.Errorf("failed to configure TLS for Vault API: %s", err)
}
err := clientConfig.ConfigureTLS(&api.TLSConfig{
CACert: d.Get("ca_cert_file").(string),
CAPath: d.Get("ca_cert_dir").(string),
Insecure: d.Get("skip_tls_verify").(bool),

clientConfig.HttpClient.Transport = logging.NewTransport("Vault", clientConfig.HttpClient.Transport)
ClientCert: clientAuthCert,
ClientKey: clientAuthKey,
})
if err != nil {
return fmt.Errorf("failed to configure TLS for Vault API: %s", err)
}

client, err := api.NewClient(clientConfig)
if err != nil {
return fmt.Errorf("failed to configure Vault API: %s", err)
}
clientConfig.HttpClient.Transport = logging.NewTransport("Vault", clientConfig.HttpClient.Transport)

client.SetCloneHeaders(true)
client, err = api.NewClient(clientConfig)
if err != nil {
return fmt.Errorf("failed to configure Vault API: %s", err)
}

// Set headers if provided
headers := d.Get("headers").([]interface{})
parsedHeaders := client.Headers().Clone()
client.SetCloneHeaders(true)

if parsedHeaders == nil {
parsedHeaders = make(http.Header)
}
// Set headers if provided
headers := d.Get("headers").([]interface{})
parsedHeaders := client.Headers().Clone()

for _, h := range headers {
header := h.(map[string]interface{})
if name, ok := header["name"]; ok {
parsedHeaders.Add(name.(string), header["value"].(string))
if parsedHeaders == nil {
parsedHeaders = make(http.Header)
}
}
client.SetHeaders(parsedHeaders)

client.SetMaxRetries(d.Get("max_retries").(int))
for _, h := range headers {
header := h.(map[string]interface{})
if name, ok := header["name"]; ok {
parsedHeaders.Add(name.(string), header["value"].(string))
}
}
client.SetHeaders(parsedHeaders)

client.SetMaxRetries(d.Get("max_retries").(int))

c.Client = client
}

// Try an get the token from the config or token helper
token, err := ProviderToken(d)
Expand Down Expand Up @@ -167,6 +179,8 @@ func (c *Client) lazyInit() error {
}
}

expireTime := time.Now()

renewable := false
childTokenLease, err := client.Auth().Token().Create(&api.TokenCreateRequest{
DisplayName: tokenName,
Expand All @@ -181,6 +195,10 @@ func (c *Client) lazyInit() error {
childToken := childTokenLease.Auth.ClientToken
policies := childTokenLease.Auth.Policies

// store the expiration time of the token.
expireTime = expireTime.Add(time.Duration(childTokenLease.Auth.LeaseDuration) * time.Second)
c.expireTime = &expireTime

log.Printf("[INFO] Using Vault token with the following policies: %s", strings.Join(policies, ", "))

// Set the token to the generated child token
Expand All @@ -192,9 +210,6 @@ func (c *Client) lazyInit() error {
client.SetNamespace(namespace)
}

c.Client = client
c.Data = nil

return nil
}

Expand Down Expand Up @@ -304,7 +319,6 @@ func SignAWSLogin(parameters map[string]interface{}) error {
Config: *config,
AssumeRoleTokenProvider: stscreds.StdinTokenProvider,
})

if err != nil {
log.Fatalf("session error: %v", err)
}
Expand Down

0 comments on commit 5869370

Please sign in to comment.