Skip to content

Commit

Permalink
Move local cluster parameters to atomic values to fix some potential …
Browse files Browse the repository at this point in the history
…data races (#4036)
  • Loading branch information
jefferai authored Feb 23, 2018
1 parent a46b996 commit ec63c19
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 66 deletions.
89 changes: 37 additions & 52 deletions vault/cluster.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package vault

import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
Expand Down Expand Up @@ -91,11 +90,9 @@ func (c *Core) Cluster(ctx context.Context) (*Cluster, error) {
func (c *Core) loadLocalClusterTLS(adv activeAdvertisement) (retErr error) {
defer func() {
if retErr != nil {
c.clusterParamsLock.Lock()
c.localClusterCert = nil
c.localClusterPrivateKey = nil
c.localClusterParsedCert = nil
c.clusterParamsLock.Unlock()
c.localClusterCert.Store(([]byte)(nil))
c.localClusterParsedCert.Store((*x509.Certificate)(nil))
c.localClusterPrivateKey.Store((*ecdsa.PrivateKey)(nil))

c.requestForwardingConnectionLock.Lock()
c.clearForwardingClients()
Expand Down Expand Up @@ -126,28 +123,26 @@ func (c *Core) loadLocalClusterTLS(adv activeAdvertisement) (retErr error) {

}

// Prevent data races with the TLS parameters
c.clusterParamsLock.Lock()
defer c.clusterParamsLock.Unlock()

c.localClusterPrivateKey = &ecdsa.PrivateKey{
c.localClusterPrivateKey.Store(&ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: elliptic.P521(),
X: adv.ClusterKeyParams.X,
Y: adv.ClusterKeyParams.Y,
},
D: adv.ClusterKeyParams.D,
}
})

c.localClusterCert = adv.ClusterCert
locCert := make([]byte, len(adv.ClusterCert))
copy(locCert, adv.ClusterCert)
c.localClusterCert.Store(locCert)

cert, err := x509.ParseCertificate(c.localClusterCert)
cert, err := x509.ParseCertificate(adv.ClusterCert)
if err != nil {
c.logger.Error("core: failed parsing local cluster certificate", "error", err)
return fmt.Errorf("error parsing local cluster certificate: %v", err)
}

c.localClusterParsedCert = cert
c.localClusterParsedCert.Store(cert)

return nil
}
Expand Down Expand Up @@ -210,19 +205,19 @@ func (c *Core) setupCluster(ctx context.Context) error {
// If we're using HA, generate server-to-server parameters
if c.ha != nil {
// Create a private key
if c.localClusterPrivateKey == nil {
if c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey) == nil {
c.logger.Trace("core: generating cluster private key")
key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
if err != nil {
c.logger.Error("core: failed to generate local cluster key", "error", err)
return err
}

c.localClusterPrivateKey = key
c.localClusterPrivateKey.Store(key)
}

// Create a certificate
if c.localClusterCert == nil {
if c.localClusterCert.Load().([]byte) == nil {
c.logger.Trace("core: generating local cluster certificate")

host, err := uuid.GenerateUUID()
Expand All @@ -248,7 +243,7 @@ func (c *Core) setupCluster(ctx context.Context) error {
IsCA: true,
}

certBytes, err := x509.CreateCertificate(rand.Reader, template, template, c.localClusterPrivateKey.Public(), c.localClusterPrivateKey)
certBytes, err := x509.CreateCertificate(rand.Reader, template, template, c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey).Public(), c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey))
if err != nil {
c.logger.Error("core: error generating self-signed cert", "error", err)
return errwrap.Wrapf("unable to generate local cluster certificate: {{err}}", err)
Expand All @@ -260,8 +255,8 @@ func (c *Core) setupCluster(ctx context.Context) error {
return errwrap.Wrapf("error parsing generated certificate: {{err}}", err)
}

c.localClusterCert = certBytes
c.localClusterParsedCert = parsedCert
c.localClusterCert.Store(certBytes)
c.localClusterParsedCert.Store(parsedCert)
}
}

Expand Down Expand Up @@ -349,24 +344,20 @@ func (c *Core) ClusterTLSConfig(ctx context.Context, repClusters *ReplicatedClus
serverLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
switch {
default:
var localCert bytes.Buffer

c.clusterParamsLock.RLock()
localCert.Write(c.localClusterCert)
localSigner := c.localClusterPrivateKey
parsedCert := c.localClusterParsedCert
c.clusterParamsLock.RUnlock()

if localCert.Len() == 0 {
currCert := c.localClusterCert.Load().([]byte)
if len(currCert) == 0 {
return nil, fmt.Errorf("got forwarding connection but no local cert")
}

localCert := make([]byte, len(currCert))
copy(localCert, currCert)

//c.logger.Trace("core: performing cert name lookup", "hello_server_name", clientHello.ServerName, "local_cluster_cert_name", parsedCert.Subject.CommonName)

return &tls.Certificate{
Certificate: [][]byte{localCert.Bytes()},
PrivateKey: localSigner,
Leaf: parsedCert,
Certificate: [][]byte{localCert},
PrivateKey: c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey),
Leaf: c.localClusterParsedCert.Load().(*x509.Certificate),
}, nil
}
}
Expand All @@ -377,22 +368,19 @@ func (c *Core) ClusterTLSConfig(ctx context.Context, repClusters *ReplicatedClus
if len(requestInfo.AcceptableCAs) != 1 {
return nil, fmt.Errorf("expected only a single acceptable CA")
}
var localCert bytes.Buffer

c.clusterParamsLock.RLock()
localCert.Write(c.localClusterCert)
localSigner := c.localClusterPrivateKey
parsedCert := c.localClusterParsedCert
c.clusterParamsLock.RUnlock()

if localCert.Len() == 0 {
currCert := c.localClusterCert.Load().([]byte)
if len(currCert) == 0 {
return nil, fmt.Errorf("forwarding connection client but no local cert")
}

localCert := make([]byte, len(currCert))
copy(localCert, currCert)

return &tls.Certificate{
Certificate: [][]byte{localCert.Bytes()},
PrivateKey: localSigner,
Leaf: parsedCert,
Certificate: [][]byte{localCert},
PrivateKey: c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey),
Leaf: c.localClusterParsedCert.Load().(*x509.Certificate),
}, nil
}

Expand Down Expand Up @@ -421,9 +409,7 @@ func (c *Core) ClusterTLSConfig(ctx context.Context, repClusters *ReplicatedClus

switch {
default:
c.clusterParamsLock.RLock()
parsedCert := c.localClusterParsedCert
c.clusterParamsLock.RUnlock()
parsedCert := c.localClusterParsedCert.Load().(*x509.Certificate)

if parsedCert == nil {
return nil, fmt.Errorf("forwarding connection client but no local cert")
Expand All @@ -444,11 +430,10 @@ func (c *Core) ClusterTLSConfig(ctx context.Context, repClusters *ReplicatedClus
CipherSuites: c.clusterCipherSuites,
}

var localCert bytes.Buffer
c.clusterParamsLock.RLock()
localCert.Write(c.localClusterCert)
parsedCert := c.localClusterParsedCert
c.clusterParamsLock.RUnlock()
parsedCert := c.localClusterParsedCert.Load().(*x509.Certificate)
currCert := c.localClusterCert.Load().([]byte)
localCert := make([]byte, len(currCert))
copy(localCert, currCert)

if parsedCert != nil {
tlsConfig.ServerName = parsedCert.Subject.CommonName
Expand Down
34 changes: 20 additions & 14 deletions vault/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package vault

import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/subtle"
"crypto/x509"
Expand Down Expand Up @@ -307,11 +306,11 @@ type Core struct {
clusterParamsLock sync.RWMutex
// The private key stored in the barrier used for establishing
// mutually-authenticated connections between Vault cluster members
localClusterPrivateKey crypto.Signer
localClusterPrivateKey *atomic.Value
// The local cluster cert
localClusterCert []byte
localClusterCert *atomic.Value
// The parsed form of the local cluster cert
localClusterParsedCert *x509.Certificate
localClusterParsedCert *atomic.Value
// The TCP addresses we should use for clustering
clusterListenerAddrs []*net.TCPAddr
// The handler to use for request forwarding
Expand Down Expand Up @@ -497,10 +496,16 @@ func NewCore(conf *CoreConfig) (*Core, error) {
rpcServerActive: new(uint32),
atomicPrimaryClusterAddrs: new(atomic.Value),
atomicPrimaryFailoverAddrs: new(atomic.Value),
localClusterPrivateKey: new(atomic.Value),
localClusterCert: new(atomic.Value),
localClusterParsedCert: new(atomic.Value),
activeNodeReplicationState: new(uint32),
}

atomic.StoreUint32(c.replicationState, uint32(consts.ReplicationDRDisabled|consts.ReplicationPerformanceDisabled))
c.localClusterCert.Store(([]byte)(nil))
c.localClusterParsedCert.Store((*x509.Certificate)(nil))
c.localClusterPrivateKey.Store((*ecdsa.PrivateKey)(nil))

if conf.ClusterCipherSuites != "" {
suites, err := tlsutil.ParseCiphers(conf.ClusterCipherSuites)
Expand Down Expand Up @@ -1816,11 +1821,9 @@ func (c *Core) runStandby(doneCh, stopCh, manualStepDownCh chan struct{}) {

// Clear previous local cluster cert info so we generate new. Since the
// UUID will have changed, standbys will know to look for new info
c.clusterParamsLock.Lock()
c.localClusterCert = nil
c.localClusterParsedCert = nil
c.localClusterPrivateKey = nil
c.clusterParamsLock.Unlock()
c.localClusterParsedCert.Store((*x509.Certificate)(nil))
c.localClusterCert.Store(([]byte)(nil))
c.localClusterPrivateKey.Store((*ecdsa.PrivateKey)(nil))

if err := c.setupCluster(ctx); err != nil {
c.stateLock.Unlock()
Expand Down Expand Up @@ -2049,12 +2052,12 @@ func (c *Core) advertiseLeader(ctx context.Context, uuid string, leaderLostCh <-
go c.cleanLeaderPrefix(ctx, uuid, leaderLostCh)

var key *ecdsa.PrivateKey
switch c.localClusterPrivateKey.(type) {
switch c.localClusterPrivateKey.Load().(type) {
case *ecdsa.PrivateKey:
key = c.localClusterPrivateKey.(*ecdsa.PrivateKey)
key = c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey)
default:
c.logger.Error("core: unknown cluster private key type", "key_type", fmt.Sprintf("%T", c.localClusterPrivateKey))
return fmt.Errorf("unknown cluster private key type %T", c.localClusterPrivateKey)
c.logger.Error("core: unknown cluster private key type", "key_type", fmt.Sprintf("%T", c.localClusterPrivateKey.Load()))
return fmt.Errorf("unknown cluster private key type %T", c.localClusterPrivateKey.Load())
}

keyParams := &clusterKeyParams{
Expand All @@ -2064,10 +2067,13 @@ func (c *Core) advertiseLeader(ctx context.Context, uuid string, leaderLostCh <-
D: key.D,
}

locCert := c.localClusterCert.Load().([]byte)
localCert := make([]byte, len(locCert))
copy(localCert, locCert)
adv := &activeAdvertisement{
RedirectAddr: c.redirectAddr,
ClusterAddr: c.clusterAddr,
ClusterCert: c.localClusterCert,
ClusterCert: localCert,
ClusterKeyParams: keyParams,
}
val, err := jsonutil.EncodeJSON(adv)
Expand Down

0 comments on commit ec63c19

Please sign in to comment.