diff --git a/cli.go b/cli.go index 76474a880..900822b5b 100644 --- a/cli.go +++ b/cli.go @@ -505,6 +505,11 @@ func (cli *CLI) ParseFlags(args []string) ( return nil }), "vault-retry-max-backoff", "") + flags.Var((funcVar)(func(s string) error { + c.Vault.ClientUserAgent = config.String(s) + return nil + }), "vault-client-user-agent", "") + flags.Var((funcBoolVar)(func(b bool) error { c.Vault.SSL.Enabled = config.Bool(b) return nil diff --git a/config/config_test.go b/config/config_test.go index 3e798d40e..556c3c208 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1343,6 +1343,18 @@ func TestParse(t *testing.T) { }, false, }, + { + "vault_user_agent", + `vault { + client_user_agent = "my-user-agent" + }`, + &Config{ + Vault: &VaultConfig{ + ClientUserAgent: String("my-user-agent"), + }, + }, + false, + }, { "vault_token", `vault { diff --git a/config/vault.go b/config/vault.go index 332dac50e..db427c0a8 100644 --- a/config/vault.go +++ b/config/vault.go @@ -86,6 +86,10 @@ type VaultConfig struct { // UnwrapToken unwraps the provided Vault token as a wrapped token. UnwrapToken *bool `mapstructure:"unwrap_token"` + // ClientUserAgent is the User-Agent header that will be set on the client + // when making requests to Vault. + ClientUserAgent *string `mapstructure:"client_user_agent""` + // DefaultLeaseDuration configures the default lease duration when not explicitly // set by vault DefaultLeaseDuration *time.Duration `mapstructure:"default_lease_duration"` @@ -233,6 +237,10 @@ func (c *VaultConfig) Merge(o *VaultConfig) *VaultConfig { r.VaultAgentTokenFile = o.VaultAgentTokenFile } + if o.ClientUserAgent != nil { + r.ClientUserAgent = o.ClientUserAgent + } + if o.Transport != nil { r.Transport = r.Transport.Merge(o.Transport) } diff --git a/dependency/client_set.go b/dependency/client_set.go index 77f67ef34..909ba03f8 100644 --- a/dependency/client_set.go +++ b/dependency/client_set.go @@ -86,17 +86,18 @@ type CreateConsulClientInput struct { // CreateVaultClientInput is used as input to the CreateVaultClient function. type CreateVaultClientInput struct { - Address string - Namespace string - Token string - UnwrapToken bool - SSLEnabled bool - SSLVerify bool - SSLCert string - SSLKey string - SSLCACert string - SSLCAPath string - ServerName string + Address string + Namespace string + Token string + UnwrapToken bool + SSLEnabled bool + SSLVerify bool + SSLCert string + SSLKey string + SSLCACert string + SSLCAPath string + ServerName string + ClientUserAgent string K8SAuthRoleName string K8SServiceAccountTokenPath string @@ -337,6 +338,11 @@ func (c *ClientSet) CreateVaultClient(i *CreateVaultClientInput) error { return fmt.Errorf("client set: vault: %s", err) } + if i.ClientUserAgent != "" { + client.SetCloneHeaders(true) + client.AddHeader("User-Agent", i.ClientUserAgent) + } + // Set the namespace if given. if i.Namespace != "" { client.SetNamespace(i.Namespace) diff --git a/dependency/client_set_test.go b/dependency/client_set_test.go index 28c06e593..4cc1e7472 100644 --- a/dependency/client_set_test.go +++ b/dependency/client_set_test.go @@ -17,6 +17,8 @@ import ( "github.com/stretchr/testify/require" ) +const userAgent = "my-user-agent" + func TestClientSet_K8SServiceTokenAuth(t *testing.T) { t.Parallel() @@ -46,6 +48,7 @@ func TestClientSet_K8SServiceTokenAuth(t *testing.T) { clientSet := NewClientSet() err := clientSet.CreateVaultClient(&CreateVaultClientInput{ Address: testServerAddr, + ClientUserAgent: userAgent, K8SAuthRoleName: "default", K8SServiceAccountToken: "service_token", }) @@ -75,6 +78,7 @@ func TestClientSet_K8SServiceTokenAuth(t *testing.T) { clientSet := NewClientSet() err := clientSet.CreateVaultClient(&CreateVaultClientInput{ Address: testServerAddr, + ClientUserAgent: userAgent, K8SAuthRoleName: "default_file", K8SServiceAccountTokenPath: f.Name(), }) @@ -104,6 +108,7 @@ func TestClientSet_K8SServiceTokenAuth(t *testing.T) { clientSet := NewClientSet() err := clientSet.CreateVaultClient(&CreateVaultClientInput{ Address: testServerAddr, + ClientUserAgent: userAgent, K8SAuthRoleName: "default", K8SServiceAccountTokenPath: f.Name(), K8SServiceAccountToken: "service_token_value", @@ -129,6 +134,7 @@ func TestClientSet_K8SServiceTokenAuth(t *testing.T) { clientSet := NewClientSet() err := clientSet.CreateVaultClient(&CreateVaultClientInput{ Address: testServerAddr, + ClientUserAgent: userAgent, K8SAuthRoleName: "default", K8SServiceAccountToken: "service_token", K8SServiceMountPath: "mount_path", @@ -149,6 +155,7 @@ func TestClientSet_K8SServiceTokenAuth(t *testing.T) { clientSet := NewClientSet() err := clientSet.CreateVaultClient(&CreateVaultClientInput{ Address: testServerAddr, + ClientUserAgent: userAgent, Token: vaultToken, K8SAuthRoleName: "default", K8SServiceAccountToken: "service_token", @@ -172,6 +179,7 @@ func TestClientSet_K8SServiceTokenAuth(t *testing.T) { clientSet := NewClientSet() err := clientSet.CreateVaultClient(&CreateVaultClientInput{ Address: testServerAddr, + ClientUserAgent: userAgent, K8SAuthRoleName: "default", K8SServiceAccountToken: "service_token", }) @@ -192,6 +200,10 @@ func (m vaultMock) processReq(tb testing.TB, w http.ResponseWriter, r *http.Requ return } + if r.UserAgent() != userAgent { + tb.Fatalf("User-Agent header not as expected. Expected %s, got %s. Request was to %s", userAgent, r.UserAgent(), r.RequestURI) + } + var data map[string]interface{} err := json.NewDecoder(r.Body).Decode(&data) if !assert.NoError(tb, err) { diff --git a/docs/configuration.md b/docs/configuration.md index c2bb518ca..3f9092c20 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -381,6 +381,10 @@ vault { # # This value can also be specified via the environment variable VAULT_NAMESPACE. namespace = "" + + # This is an optional configuration item that, if set, will determine the + # User-Agent header to use on all requests to Vault. + client_user_agent = "Consul Template" # This is the token to use when communicating with the Vault server. # Like other tools that integrate with Vault, Consul Template makes the diff --git a/manager/runner.go b/manager/runner.go index 25cd41f16..995f5c556 100644 --- a/manager/runner.go +++ b/manager/runner.go @@ -1351,6 +1351,7 @@ func NewClientSet(c *config.Config) (*dep.ClientSet, error) { SSLCACert: config.StringVal(c.Vault.SSL.CaCert), SSLCAPath: config.StringVal(c.Vault.SSL.CaPath), ServerName: config.StringVal(c.Vault.SSL.ServerName), + ClientUserAgent: config.StringVal(c.Vault.ClientUserAgent), TransportCustomDialer: c.Vault.Transport.CustomDialer, TransportDialKeepAlive: config.TimeDurationVal(c.Vault.Transport.DialKeepAlive), TransportDialTimeout: config.TimeDurationVal(c.Vault.Transport.DialTimeout),