Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure a consistent TLS configuration #173

Merged
merged 12 commits into from
Jan 12, 2023
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ setup-integration-test: teardown-integration-test vault-image
--set server.dev.enabled=true \
--set server.image.tag=dev \
--set server.image.pullPolicy=Never \
--set server.logLevel=trace \
--set injector.enabled=false \
--set server.extraArgs="-dev-plugin-dir=/vault/plugin_directory"
kubectl patch --namespace=test statefulset vault --patch-file integrationtest/vault/hostPortPatch.yaml
Expand Down
176 changes: 155 additions & 21 deletions backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@ import (
"encoding/json"
"fmt"
"net/http"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"

"github.com/hashicorp/go-cleanhttp"
Expand Down Expand Up @@ -53,6 +56,9 @@ type kubeAuthBackend struct {
// default HTTP client for connection reuse
httpClient *http.Client

// tlsConfig is periodically updated whenever the CA certificate configuration changes.
tlsConfig *tls.Config

// reviewFactory is used to configure the strategy for doing a token review.
// Currently the only options are using the kubernetes API or mocking the
// review. Mocks should only be used in tests.
Expand Down Expand Up @@ -83,10 +89,22 @@ func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend,
return b, nil
}

var getDefaultHTTPClient = cleanhttp.DefaultPooledClient

func defaultTLSConfig() *tls.Config {
return &tls.Config{
MinVersion: tls.VersionTLS12,
}
}

func Backend() *kubeAuthBackend {
b := &kubeAuthBackend{
localSATokenReader: newCachingFileReader(localJWTPath, jwtReloadPeriod, time.Now),
localCACertReader: newCachingFileReader(localCACertPath, caReloadPeriod, time.Now),
// Set default HTTP client
httpClient: getDefaultHTTPClient(),
// Set the review factory to default to calling into the kubernetes API.
reviewFactory: tokenReviewAPIFactory,
}

b.Backend = &framework.Backend{
Expand All @@ -111,41 +129,82 @@ func Backend() *kubeAuthBackend {
InitializeFunc: b.initialize,
}

// Set default HTTP client
b.httpClient = cleanhttp.DefaultPooledClient()

// Set the review factory to default to calling into the kubernetes API.
b.reviewFactory = tokenReviewAPIFactory

return b
}

// initialize is used to handle the state of config values just after the K8s plugin has been mounted
func (b *kubeAuthBackend) initialize(ctx context.Context, req *logical.InitializationRequest) error {
// Try to load the config on initialization
config, err := b.loadConfig(ctx, req.Storage)
config, err := b.config(ctx, req.Storage)
if err != nil {
benashz marked this conversation as resolved.
Show resolved Hide resolved
return err
}
if config == nil {
return nil

if config != nil {
if err := b.updateTLSConfig(config); err != nil {
return err
}
}

b.l.Lock()
defer b.l.Unlock()
// If we have a CA cert build the TLSConfig
if len(config.CACert) > 0 {
certPool := x509.NewCertPool()
certPool.AppendCertsFromPEM([]byte(config.CACert))

tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
RootCAs: certPool,
return b.runTLSConfigUpdater(context.Background(), req.Storage)
}

// runTLSConfigUpdater sets up a routine that periodically calls b.updateTLSConfig(). This ensures that the
// httpClient's TLS configuration is consistent with the backend's stored configuration.
func (b *kubeAuthBackend) runTLSConfigUpdater(ctx context.Context, s logical.Storage) error {
updateTLSConfig := func(ctx context.Context, s logical.Storage, force bool) error {
benashz marked this conversation as resolved.
Show resolved Hide resolved
config, err := b.config(ctx, s)
benashz marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return fmt.Errorf("failed config read, err=%w", err)
}

if config == nil {
b.Logger().Trace("Skipping TLSConfig update, no configuration set")
return nil
}

if err := b.updateTLSConfig(config); err != nil {
return err
}

b.httpClient.Transport.(*http.Transport).TLSClientConfig = tlsConfig
return nil
}

horizon := time.Second * 30
ticker := time.NewTicker(horizon)
wCtx, cancel := context.WithCancel(ctx)
tomhjp marked this conversation as resolved.
Show resolved Hide resolved
go func(ctx context.Context, cancel context.CancelFunc, s logical.Storage) {
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGHUP, syscall.SIGTERM, syscall.SIGINT)
benashz marked this conversation as resolved.
Show resolved Hide resolved
defer signal.Stop(sigs)

b.Logger().Trace("Starting TLS config updater", "horizon", horizon.String())
for {
select {
case <-ctx.Done():
tomhjp marked this conversation as resolved.
Show resolved Hide resolved
b.Logger().Trace("Shutting down TLS config updater")
return
case <-ticker.C:
if err := updateTLSConfig(ctx, s, false); err != nil {
b.Logger().Warn("Retrying failed update", "horizon", horizon.String(), "err", err)
}
case sig := <-sigs:
b.Logger().Trace(fmt.Sprintf("Caught signal %v", sig))
switch sig {
case syscall.SIGHUP:
// update the TLS configuration when the plugin process receives a SIGHUP
b.Logger().Trace(fmt.Sprintf("Calling updateTLSConfig() on signal %v", sig))
if err := updateTLSConfig(ctx, s, true); err != nil {
b.Logger().Warn("Retrying failed update", "horizon", horizon.String(), "err", err)
}
default:
// shutdown on all other signals
b.Logger().Trace(fmt.Sprintf("Calling cancel() on signal %v", sig))
cancel()
}
}
}
}(wCtx, cancel, s)

return nil
}

Expand Down Expand Up @@ -255,6 +314,81 @@ func (b *kubeAuthBackend) role(ctx context.Context, s logical.Storage, name stri
return role, nil
}

// getHTTPClient return the backend's HTTP client for connecting to the Kubernetes API.
func (b *kubeAuthBackend) getHTTPClient(config *kubeConfig) (*http.Client, error) {
if b.httpClient == nil {
return nil, fmt.Errorf("the backend's http.Client has not been initialized")
}

if b.tlsConfig == nil {
benashz marked this conversation as resolved.
Show resolved Hide resolved
// ensure that HTTP client's transport TLS configuration is initialized
// this adds some belt-and-suspenders,
// since in most cases the TLS configuration would have already been initialized.
if err := b.updateTLSConfig(config); err != nil {
return nil, err
}
}

return b.httpClient, nil
}

// updateTLSConfig ensures that the httpClient's TLS configuration is consistent
// with backend's stored configuration.
func (b *kubeAuthBackend) updateTLSConfig(config *kubeConfig) error {
b.l.Lock()
defer b.l.Unlock()

if b.httpClient == nil {
return fmt.Errorf("the backend's http.Client has not been initialized")
}

// attempt to read the CA certificates the config directly or from the filesystem.
benashz marked this conversation as resolved.
Show resolved Hide resolved
var caCertBytes []byte
if config.CACert != "" {
caCertBytes = []byte(config.CACert)
} else if !config.DisableLocalCAJwt && b.localCACertReader != nil {
// TODO: this may block on I/O, investigate a proper mitigation
benashz marked this conversation as resolved.
Show resolved Hide resolved
data, err := b.localCACertReader.ReadFile()
if err != nil {
return err
}
caCertBytes = []byte(data)
}

transport, ok := b.httpClient.Transport.(*http.Transport)
if !ok {
// should never happen
return fmt.Errorf("type assertion failed for %T", b.httpClient.Transport)
}
benashz marked this conversation as resolved.
Show resolved Hide resolved

if b.tlsConfig == nil {
b.tlsConfig = defaultTLSConfig()
}
benashz marked this conversation as resolved.
Show resolved Hide resolved

certPool := x509.NewCertPool()
if len(caCertBytes) > 0 {
if ok := certPool.AppendCertsFromPEM(caCertBytes); !ok {
b.Logger().Warn("Configured CA PEM data contains no valid certificates, TLS verification will fail")
}
} else {
// provide an empty certPool
b.Logger().Warn("No CA certificates configured, TLS verification will fail")
// TODO: think about supporting host root CA certificates via a configuration toggle,
// in which case RootCAs should be set to nil
}

// only refresh the Root CAs if they have changed since the last full update.
if !b.tlsConfig.RootCAs.Equal(certPool) {
b.Logger().Trace("Root CA certificate pool has changed, updating the client's transport")
b.tlsConfig.RootCAs = certPool
transport.TLSClientConfig = b.tlsConfig
tvoran marked this conversation as resolved.
Show resolved Hide resolved
} else {
b.Logger().Trace("Root CA certificate pool is unchanged, no update required")
}

return nil
}

func validateAliasNameSource(source string) error {
for _, s := range aliasNameSources {
if s == source {
Expand Down
33 changes: 5 additions & 28 deletions path_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@ import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"net/http"

"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
Expand Down Expand Up @@ -157,32 +155,6 @@ func (b *kubeAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Requ
DisableLocalCAJwt: disableLocalJWT,
}

b.l.Lock()
defer b.l.Unlock()
tomhjp marked this conversation as resolved.
Show resolved Hide resolved

// Determine if we load the local CA cert or the CA cert provided
// by the kubernetes_ca_cert path into the backend's HTTP client
certPool := x509.NewCertPool()
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
}
if disableLocalJWT || len(caCert) > 0 {
certPool.AppendCertsFromPEM([]byte(config.CACert))
tlsConfig.RootCAs = certPool

b.httpClient.Transport.(*http.Transport).TLSClientConfig = tlsConfig
} else {
localCACert, err := b.localCACertReader.ReadFile()
if err != nil {
return nil, err
}

certPool.AppendCertsFromPEM([]byte(localCACert))
tlsConfig.RootCAs = certPool

b.httpClient.Transport.(*http.Transport).TLSClientConfig = tlsConfig
}

var err error
for i, pem := range pemList {
config.PublicKeys[i], err = parsePublicKeyPEM([]byte(pem))
Expand All @@ -191,6 +163,10 @@ func (b *kubeAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Requ
}
}

if err := b.updateTLSConfig(config); err != nil {
return logical.ErrorResponse(err.Error()), nil
}

entry, err := logical.StorageEntryJSON(configPath, config)
if err != nil {
return nil, err
Expand All @@ -199,6 +175,7 @@ func (b *kubeAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Requ
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}

return nil, nil
}

Expand Down
7 changes: 6 additions & 1 deletion path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,13 @@ func (b *kubeAuthBackend) pathLogin(ctx context.Context, req *logical.Request, d
return nil, err
}

client, err := b.getHTTPClient(config)
tomhjp marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return nil, logical.ErrUnrecoverable
tomhjp marked this conversation as resolved.
Show resolved Hide resolved
}

// look up the JWT token in the kubernetes API
err = serviceAccount.lookup(ctx, b.httpClient, jwtStr, b.reviewFactory(config))
err = serviceAccount.lookup(ctx, client, jwtStr, b.reviewFactory(config))

if err != nil {
b.Logger().Debug(`login unauthorized`, "err", err)
Expand Down