Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move local cluster parameters to atomic values to fix some potential locking deadlocks #4036

Merged
merged 1 commit into from
Feb 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 37 additions & 52 deletions vault/cluster.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package vault

import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
Expand Down Expand Up @@ -87,11 +86,9 @@ func (c *Core) Cluster(ctx context.Context) (*Cluster, error) {
func (c *Core) loadLocalClusterTLS(adv activeAdvertisement) (retErr error) {
defer func() {
if retErr != nil {
c.clusterParamsLock.Lock()
c.localClusterCert = nil
c.localClusterPrivateKey = nil
c.localClusterParsedCert = nil
c.clusterParamsLock.Unlock()
c.localClusterCert.Store(([]byte)(nil))
c.localClusterParsedCert.Store((*x509.Certificate)(nil))
c.localClusterPrivateKey.Store((*ecdsa.PrivateKey)(nil))

c.requestForwardingConnectionLock.Lock()
c.clearForwardingClients()
Expand Down Expand Up @@ -122,28 +119,26 @@ func (c *Core) loadLocalClusterTLS(adv activeAdvertisement) (retErr error) {

}

// Prevent data races with the TLS parameters
c.clusterParamsLock.Lock()
defer c.clusterParamsLock.Unlock()

c.localClusterPrivateKey = &ecdsa.PrivateKey{
c.localClusterPrivateKey.Store(&ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: elliptic.P521(),
X: adv.ClusterKeyParams.X,
Y: adv.ClusterKeyParams.Y,
},
D: adv.ClusterKeyParams.D,
}
})

c.localClusterCert = adv.ClusterCert
locCert := make([]byte, len(adv.ClusterCert))
copy(locCert, adv.ClusterCert)
c.localClusterCert.Store(locCert)

cert, err := x509.ParseCertificate(c.localClusterCert)
cert, err := x509.ParseCertificate(adv.ClusterCert)
if err != nil {
c.logger.Error("core: failed parsing local cluster certificate", "error", err)
return fmt.Errorf("error parsing local cluster certificate: %v", err)
}

c.localClusterParsedCert = cert
c.localClusterParsedCert.Store(cert)

return nil
}
Expand Down Expand Up @@ -206,19 +201,19 @@ func (c *Core) setupCluster(ctx context.Context) error {
// If we're using HA, generate server-to-server parameters
if c.ha != nil {
// Create a private key
if c.localClusterPrivateKey == nil {
if c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey) == nil {
c.logger.Trace("core: generating cluster private key")
key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
if err != nil {
c.logger.Error("core: failed to generate local cluster key", "error", err)
return err
}

c.localClusterPrivateKey = key
c.localClusterPrivateKey.Store(key)
}

// Create a certificate
if c.localClusterCert == nil {
if c.localClusterCert.Load().([]byte) == nil {
c.logger.Trace("core: generating local cluster certificate")

host, err := uuid.GenerateUUID()
Expand All @@ -244,7 +239,7 @@ func (c *Core) setupCluster(ctx context.Context) error {
IsCA: true,
}

certBytes, err := x509.CreateCertificate(rand.Reader, template, template, c.localClusterPrivateKey.Public(), c.localClusterPrivateKey)
certBytes, err := x509.CreateCertificate(rand.Reader, template, template, c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey).Public(), c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey))
if err != nil {
c.logger.Error("core: error generating self-signed cert", "error", err)
return errwrap.Wrapf("unable to generate local cluster certificate: {{err}}", err)
Expand All @@ -256,8 +251,8 @@ func (c *Core) setupCluster(ctx context.Context) error {
return errwrap.Wrapf("error parsing generated certificate: {{err}}", err)
}

c.localClusterCert = certBytes
c.localClusterParsedCert = parsedCert
c.localClusterCert.Store(certBytes)
c.localClusterParsedCert.Store(parsedCert)
}
}

Expand Down Expand Up @@ -345,24 +340,20 @@ func (c *Core) ClusterTLSConfig(ctx context.Context) (*tls.Config, error) {
serverLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
switch {
default:
var localCert bytes.Buffer

c.clusterParamsLock.RLock()
localCert.Write(c.localClusterCert)
localSigner := c.localClusterPrivateKey
parsedCert := c.localClusterParsedCert
c.clusterParamsLock.RUnlock()

if localCert.Len() == 0 {
currCert := c.localClusterCert.Load().([]byte)
if len(currCert) == 0 {
return nil, fmt.Errorf("got forwarding connection but no local cert")
}

localCert := make([]byte, len(currCert))
copy(localCert, currCert)

//c.logger.Trace("core: performing cert name lookup", "hello_server_name", clientHello.ServerName, "local_cluster_cert_name", parsedCert.Subject.CommonName)

return &tls.Certificate{
Certificate: [][]byte{localCert.Bytes()},
PrivateKey: localSigner,
Leaf: parsedCert,
Certificate: [][]byte{localCert},
PrivateKey: c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey),
Leaf: c.localClusterParsedCert.Load().(*x509.Certificate),
}, nil
}
}
Expand All @@ -373,22 +364,19 @@ func (c *Core) ClusterTLSConfig(ctx context.Context) (*tls.Config, error) {
if len(requestInfo.AcceptableCAs) != 1 {
return nil, fmt.Errorf("expected only a single acceptable CA")
}
var localCert bytes.Buffer

c.clusterParamsLock.RLock()
localCert.Write(c.localClusterCert)
localSigner := c.localClusterPrivateKey
parsedCert := c.localClusterParsedCert
c.clusterParamsLock.RUnlock()

if localCert.Len() == 0 {
currCert := c.localClusterCert.Load().([]byte)
if len(currCert) == 0 {
return nil, fmt.Errorf("forwarding connection client but no local cert")
}

localCert := make([]byte, len(currCert))
copy(localCert, currCert)

return &tls.Certificate{
Certificate: [][]byte{localCert.Bytes()},
PrivateKey: localSigner,
Leaf: parsedCert,
Certificate: [][]byte{localCert},
PrivateKey: c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey),
Leaf: c.localClusterParsedCert.Load().(*x509.Certificate),
}, nil
}

Expand Down Expand Up @@ -417,9 +405,7 @@ func (c *Core) ClusterTLSConfig(ctx context.Context) (*tls.Config, error) {

switch {
default:
c.clusterParamsLock.RLock()
parsedCert := c.localClusterParsedCert
c.clusterParamsLock.RUnlock()
parsedCert := c.localClusterParsedCert.Load().(*x509.Certificate)

if parsedCert == nil {
return nil, fmt.Errorf("forwarding connection client but no local cert")
Expand All @@ -440,11 +426,10 @@ func (c *Core) ClusterTLSConfig(ctx context.Context) (*tls.Config, error) {
CipherSuites: c.clusterCipherSuites,
}

var localCert bytes.Buffer
c.clusterParamsLock.RLock()
localCert.Write(c.localClusterCert)
parsedCert := c.localClusterParsedCert
c.clusterParamsLock.RUnlock()
parsedCert := c.localClusterParsedCert.Load().(*x509.Certificate)
currCert := c.localClusterCert.Load().([]byte)
localCert := make([]byte, len(currCert))
copy(localCert, currCert)

if parsedCert != nil {
tlsConfig.ServerName = parsedCert.Subject.CommonName
Expand Down
34 changes: 20 additions & 14 deletions vault/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package vault

import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/subtle"
"crypto/x509"
Expand Down Expand Up @@ -307,11 +306,11 @@ type Core struct {
clusterParamsLock sync.RWMutex
// The private key stored in the barrier used for establishing
// mutually-authenticated connections between Vault cluster members
localClusterPrivateKey crypto.Signer
localClusterPrivateKey *atomic.Value
// The local cluster cert
localClusterCert []byte
localClusterCert *atomic.Value
// The parsed form of the local cluster cert
localClusterParsedCert *x509.Certificate
localClusterParsedCert *atomic.Value
// The TCP addresses we should use for clustering
clusterListenerAddrs []*net.TCPAddr
// The handler to use for request forwarding
Expand Down Expand Up @@ -497,10 +496,16 @@ func NewCore(conf *CoreConfig) (*Core, error) {
rpcServerActive: new(uint32),
atomicPrimaryClusterAddrs: new(atomic.Value),
atomicPrimaryFailoverAddrs: new(atomic.Value),
localClusterPrivateKey: new(atomic.Value),
localClusterCert: new(atomic.Value),
localClusterParsedCert: new(atomic.Value),
activeNodeReplicationState: new(uint32),
}

atomic.StoreUint32(c.replicationState, uint32(consts.ReplicationDRDisabled|consts.ReplicationPerformanceDisabled))
c.localClusterCert.Store(([]byte)(nil))
c.localClusterParsedCert.Store((*x509.Certificate)(nil))
c.localClusterPrivateKey.Store((*ecdsa.PrivateKey)(nil))

if conf.ClusterCipherSuites != "" {
suites, err := tlsutil.ParseCiphers(conf.ClusterCipherSuites)
Expand Down Expand Up @@ -1816,11 +1821,9 @@ func (c *Core) runStandby(doneCh, stopCh, manualStepDownCh chan struct{}) {

// Clear previous local cluster cert info so we generate new. Since the
// UUID will have changed, standbys will know to look for new info
c.clusterParamsLock.Lock()
c.localClusterCert = nil
c.localClusterParsedCert = nil
c.localClusterPrivateKey = nil
c.clusterParamsLock.Unlock()
c.localClusterParsedCert.Store((*x509.Certificate)(nil))
c.localClusterCert.Store(([]byte)(nil))
c.localClusterPrivateKey.Store((*ecdsa.PrivateKey)(nil))

if err := c.setupCluster(ctx); err != nil {
c.stateLock.Unlock()
Expand Down Expand Up @@ -2049,12 +2052,12 @@ func (c *Core) advertiseLeader(ctx context.Context, uuid string, leaderLostCh <-
go c.cleanLeaderPrefix(ctx, uuid, leaderLostCh)

var key *ecdsa.PrivateKey
switch c.localClusterPrivateKey.(type) {
switch c.localClusterPrivateKey.Load().(type) {
case *ecdsa.PrivateKey:
key = c.localClusterPrivateKey.(*ecdsa.PrivateKey)
key = c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey)
default:
c.logger.Error("core: unknown cluster private key type", "key_type", fmt.Sprintf("%T", c.localClusterPrivateKey))
return fmt.Errorf("unknown cluster private key type %T", c.localClusterPrivateKey)
c.logger.Error("core: unknown cluster private key type", "key_type", fmt.Sprintf("%T", c.localClusterPrivateKey.Load()))
return fmt.Errorf("unknown cluster private key type %T", c.localClusterPrivateKey.Load())
}

keyParams := &clusterKeyParams{
Expand All @@ -2064,10 +2067,13 @@ func (c *Core) advertiseLeader(ctx context.Context, uuid string, leaderLostCh <-
D: key.D,
}

locCert := c.localClusterCert.Load().([]byte)
localCert := make([]byte, len(locCert))
copy(localCert, locCert)
adv := &activeAdvertisement{
RedirectAddr: c.redirectAddr,
ClusterAddr: c.clusterAddr,
ClusterCert: c.localClusterCert,
ClusterCert: localCert,
ClusterKeyParams: keyParams,
}
val, err := jsonutil.EncodeJSON(adv)
Expand Down