Skip to content

Commit

Permalink
fix(vault): Set the Vault TLS client options if specified
Browse files Browse the repository at this point in the history
Two Vault client settings were not being properly used when constructing
a Vault client.

The `TLS Skip Verify` setting was only being set if a `CA Cert` was also
configured. This fix sets the `TLS Skip Verify` when configured
regardless of other settings. Closes #3961.

The `TLS Server Name` setting was never being set. Bad programmers. This
fix now sets it on the Vault client if the Vault Credential Store has
been configured to use a value for this setting. Closes #3962.
  • Loading branch information
mgaffney committed Nov 2, 2023
1 parent a9e3729 commit d638378
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 41 deletions.
4 changes: 3 additions & 1 deletion internal/credential/vault/supported.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func gotNewServer(t testing.TB, opt ...TestOption) *TestVaultServer {
dockerOptions.Env = append(dockerOptions.Env, fmt.Sprintf("VAULT_LOCAL_CONFIG=%s", clientTlsTemplate))
}

serverCert := testServerCert(t, testCaCert(t), "localhost")
serverCert := testServerCert(t, testCaCert(t), opt...)
server.serverCertBundle = serverCert
server.ServerCert = serverCert.Cert.Cert
server.CaCert = serverCert.CA.Cert
Expand Down Expand Up @@ -147,6 +147,8 @@ func gotNewServer(t testing.TB, opt ...TestOption) *TestVaultServer {
return &vaultClientCert, nil
}
}
server.TlsSkipVerify = true
clientTLSConfig.InsecureSkipVerify = true
}

// NOTE(mgaffney) 05/2021: creating a docker network is not the default
Expand Down
83 changes: 50 additions & 33 deletions internal/credential/vault/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ func testCaCert(t testing.TB) *testCert {
}

// testServerCert will generate a test x509 server cert.
func testServerCert(t testing.TB, ca *testCert, hosts ...string) *testCertBundle {
func testServerCert(t testing.TB, ca *testCert, opt ...TestOption) *testCertBundle {
t.Helper()
require := require.New(t)

Expand Down Expand Up @@ -404,7 +404,8 @@ func testServerCert(t testing.TB, ca *testCert, hosts ...string) *testCertBundle
BasicConstraintsValid: true,
}

for _, h := range hosts {
opts := getTestOpts(t, opt...)
for _, h := range opts.serverCertHostNames {
if ip := net.ParseIP(h); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else {
Expand Down Expand Up @@ -512,19 +513,20 @@ type TestOption func(testing.TB, *testOptions)

// options = how options are represented
type testOptions struct {
orphan bool
periodic bool
renewable bool
policies []string
allowedExtensions []string
mountPath string
roleName string
vaultTLS TestVaultTLS
dockerNetwork bool
skipCleanup bool
tokenPeriod time.Duration
clientKey *ecdsa.PrivateKey
vaultVersion string
orphan bool
periodic bool
renewable bool
policies []string
allowedExtensions []string
mountPath string
roleName string
vaultTLS TestVaultTLS
dockerNetwork bool
skipCleanup bool
tokenPeriod time.Duration
clientKey *ecdsa.PrivateKey
vaultVersion string
serverCertHostNames []string
}

func getDefaultTestOptions(t testing.TB) testOptions {
Expand All @@ -537,15 +539,16 @@ func getDefaultTestOptions(t testing.TB) testOptions {
}

return testOptions{
orphan: true,
periodic: true,
renewable: true,
policies: []string{"default", "boundary-controller"},
mountPath: "",
roleName: "boundary",
vaultTLS: TestNoTLS,
dockerNetwork: false,
tokenPeriod: defaultPeriod,
orphan: true,
periodic: true,
renewable: true,
policies: []string{"default", "boundary-controller"},
mountPath: "",
roleName: "boundary",
vaultTLS: TestNoTLS,
dockerNetwork: false,
tokenPeriod: defaultPeriod,
serverCertHostNames: []string{"localhost"},
}
}

Expand Down Expand Up @@ -587,6 +590,16 @@ func WithTestVaultTLS(s TestVaultTLS) TestOption {
}
}

// WithServerCertHostNames sets the host names or IP address to attach to
// the test server's TLS certificate. The default host name attached to the
// test server's TLS certificate is 'localhost'.
func WithServerCertHostNames(h []string) TestOption {
return func(t testing.TB, o *testOptions) {
t.Helper()
o.serverCertHostNames = h
}
}

// WithClientKey sets the private key that will be used to generate the
// client certificate. The option is only valid when used together with TestClientTLS.
func WithClientKey(k *ecdsa.PrivateKey) TestOption {
Expand Down Expand Up @@ -699,11 +712,13 @@ func (v *TestVaultServer) ClientUsingToken(t testing.TB, token string) *client {
ctx := context.Background()
require := require.New(t)
conf := &clientConfig{
Addr: v.Addr,
Token: TokenSecret(token),
CaCert: v.CaCert,
ClientCert: v.ClientCert,
ClientKey: v.ClientKey,
Addr: v.Addr,
Token: TokenSecret(token),
CaCert: v.CaCert,
ClientCert: v.ClientCert,
ClientKey: v.ClientKey,
TlsServerName: v.TlsServerName,
TlsSkipVerify: v.TlsSkipVerify,
}

client, err := newClient(ctx, conf)
Expand Down Expand Up @@ -1008,10 +1023,12 @@ type TestVaultServer struct {
RootToken string
Addr string

CaCert []byte
ServerCert []byte
ClientCert []byte
ClientKey []byte
CaCert []byte
ServerCert []byte
ClientCert []byte
ClientKey []byte
TlsServerName string
TlsSkipVerify bool

serverCertBundle *testCertBundle
clientCertBundle *testCertBundle
Expand Down
64 changes: 59 additions & 5 deletions internal/credential/vault/testing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,58 @@ func TestNewVaultServer(t *testing.T) {
require.NotNil(client)
require.NoError(client.ping(ctx))
})
t.Run("TestServerTLS-InsecureSkipVerify", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
v := NewTestVaultServer(t, WithTestVaultTLS(TestServerTLS), WithServerCertHostNames([]string{"kaz"}))
require.NotNil(v)

assert.NotEmpty(v.RootToken)
assert.NotEmpty(v.Addr)
assert.NotEmpty(v.CaCert)

conf := &clientConfig{
Addr: v.Addr,
Token: TokenSecret(v.RootToken),
CaCert: v.CaCert,
}

client, err := newClient(ctx, conf)
require.NoError(err)
require.NotNil(client)
require.Error(client.ping(ctx))

conf.TlsSkipVerify = true
client, err = newClient(ctx, conf)
require.NoError(err)
require.NotNil(client)
require.NoError(client.ping(ctx))
})
t.Run("TestServerTLS-TlsServerName", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
v := NewTestVaultServer(t, WithTestVaultTLS(TestServerTLS), WithServerCertHostNames([]string{"kaz"}))
require.NotNil(v)

assert.NotEmpty(v.RootToken)
assert.NotEmpty(v.Addr)
assert.NotEmpty(v.CaCert)

conf := &clientConfig{
Addr: v.Addr,
Token: TokenSecret(v.RootToken),
CaCert: v.CaCert,
}

client, err := newClient(ctx, conf)
require.NoError(err)
require.NotNil(client)
require.Error(client.ping(ctx))

conf.TlsServerName = "kaz"
client, err = newClient(ctx, conf)
require.NoError(err)
require.NotNil(client)
require.NoError(client.ping(ctx))
})
t.Run("TestClientTLS", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
v := NewTestVaultServer(t, WithTestVaultTLS(TestClientTLS))
Expand Down Expand Up @@ -311,11 +363,13 @@ func TestNewVaultServer(t *testing.T) {
assert.Equal(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: k}), v.ClientKey)

conf := &clientConfig{
Addr: v.Addr,
Token: TokenSecret(v.RootToken),
CaCert: v.CaCert,
ClientCert: v.ClientCert,
ClientKey: v.ClientKey,
Addr: v.Addr,
Token: TokenSecret(v.RootToken),
CaCert: v.CaCert,
TlsServerName: v.TlsServerName,
TlsSkipVerify: v.TlsSkipVerify,
ClientCert: v.ClientCert,
ClientKey: v.ClientKey,
}

client, err := newClient(ctx, conf)
Expand Down
8 changes: 6 additions & 2 deletions internal/credential/vault/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,16 @@ func newClient(ctx context.Context, c *clientConfig) (*client, error) {
}
vc := vault.DefaultConfig()
vc.Address = c.Addr
tlsConfig := vc.HttpClient.Transport.(*http.Transport).TLSClientConfig
tlsConfig.InsecureSkipVerify = c.TlsSkipVerify
if c.TlsServerName != "" {
tlsConfig.ServerName = c.TlsServerName
}

if len(c.CaCert) > 0 {
rootConfig := &rootcerts.Config{
CACertificate: c.CaCert,
}
tlsConfig := vc.HttpClient.Transport.(*http.Transport).TLSClientConfig
tlsConfig.InsecureSkipVerify = c.TlsSkipVerify
if err := rootcerts.ConfigureTLS(tlsConfig, rootConfig); err != nil {
return nil, errors.Wrap(ctx, err, op)
}
Expand Down

0 comments on commit d638378

Please sign in to comment.