Skip to content

Commit

Permalink
More porting from rep (#2389)
Browse files Browse the repository at this point in the history
* More porting from rep

* Address feedback
  • Loading branch information
jefferai authored Feb 17, 2017
1 parent 8acbdef commit 2901591
Show file tree
Hide file tree
Showing 8 changed files with 530 additions and 92 deletions.
81 changes: 71 additions & 10 deletions vault/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package vault

import (
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"strings"
Expand All @@ -26,6 +25,10 @@ const (
// can only be viewed or modified after an unseal.
coreAuditConfigPath = "core/audit"

// coreLocalAuditConfigPath is used to store audit information for local
// (non-replicated) mounts
coreLocalAuditConfigPath = "core/local-audit"

// auditBarrierPrefix is the prefix to the UUID used in the
// barrier view for the audit backends.
auditBarrierPrefix = "audit/"
Expand Down Expand Up @@ -69,12 +72,15 @@ func (c *Core) enableAudit(entry *MountEntry) error {
}

// Generate a new UUID and view
entryUUID, err := uuid.GenerateUUID()
if err != nil {
return err
if entry.UUID == "" {
entryUUID, err := uuid.GenerateUUID()
if err != nil {
return err
}
entry.UUID = entryUUID
}
entry.UUID = entryUUID
view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/")
viewPath := auditBarrierPrefix + entry.UUID + "/"
view := NewBarrierView(c.barrier, viewPath)

// Lookup the new backend
backend, err := c.newAuditBackend(entry, view, entry.Options)
Expand Down Expand Up @@ -119,6 +125,12 @@ func (c *Core) disableAudit(path string) (bool, error) {

c.removeAuditReloadFunc(entry)

// When unmounting all entries the JSON code will load back up from storage
// as a nil slice, which kills tests...just set it nil explicitly
if len(newTable.Entries) == 0 {
newTable.Entries = nil
}

// Update the audit table
if err := c.persistAudit(newTable); err != nil {
return true, errors.New("failed to update audit table")
Expand All @@ -131,19 +143,26 @@ func (c *Core) disableAudit(path string) (bool, error) {
if c.logger.IsInfo() {
c.logger.Info("core: disabled audit backend", "path", path)
}

return true, nil
}

// loadAudits is invoked as part of postUnseal to load the audit table
func (c *Core) loadAudits() error {
auditTable := &MountTable{}
localAuditTable := &MountTable{}

// Load the existing audit table
raw, err := c.barrier.Get(coreAuditConfigPath)
if err != nil {
c.logger.Error("core: failed to read audit table", "error", err)
return errLoadAuditFailed
}
rawLocal, err := c.barrier.Get(coreLocalAuditConfigPath)
if err != nil {
c.logger.Error("core: failed to read local audit table", "error", err)
return errLoadAuditFailed
}

c.auditLock.Lock()
defer c.auditLock.Unlock()
Expand All @@ -155,6 +174,13 @@ func (c *Core) loadAudits() error {
}
c.audit = auditTable
}
if rawLocal != nil {
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil {
c.logger.Error("core: failed to decode local audit table", "error", err)
return errLoadAuditFailed
}
c.audit.Entries = append(c.audit.Entries, localAuditTable.Entries...)
}

// Done if we have restored the audit table
if c.audit != nil {
Expand Down Expand Up @@ -203,24 +229,58 @@ func (c *Core) persistAudit(table *MountTable) error {
}
}

nonLocalAudit := &MountTable{
Type: auditTableType,
}

localAudit := &MountTable{
Type: auditTableType,
}

for _, entry := range table.Entries {
if entry.Local {
localAudit.Entries = append(localAudit.Entries, entry)
} else {
nonLocalAudit.Entries = append(nonLocalAudit.Entries, entry)
}
}

// Marshal the table
raw, err := json.Marshal(table)
compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalAudit, nil)
if err != nil {
c.logger.Error("core: failed to encode audit table", "error", err)
c.logger.Error("core: failed to encode and/or compress audit table", "error", err)
return err
}

// Create an entry
entry := &Entry{
Key: coreAuditConfigPath,
Value: raw,
Value: compressedBytes,
}

// Write to the physical backend
if err := c.barrier.Put(entry); err != nil {
c.logger.Error("core: failed to persist audit table", "error", err)
return err
}

// Repeat with local audit
compressedBytes, err = jsonutil.EncodeJSONAndCompress(localAudit, nil)
if err != nil {
c.logger.Error("core: failed to encode and/or compress local audit table", "error", err)
return err
}

entry = &Entry{
Key: coreLocalAuditConfigPath,
Value: compressedBytes,
}

if err := c.barrier.Put(entry); err != nil {
c.logger.Error("core: failed to persist local audit table", "error", err)
return err
}

return nil
}

Expand All @@ -236,7 +296,8 @@ func (c *Core) setupAudits() error {

for _, entry := range c.audit.Entries {
// Create a barrier view using the UUID
view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/")
viewPath := auditBarrierPrefix + entry.UUID + "/"
view := NewBarrierView(c.barrier, viewPath)

// Initialize the backend
audit, err := c.newAuditBackend(entry, view, entry.Options)
Expand Down
91 changes: 90 additions & 1 deletion vault/audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/logformat"
"github.com/hashicorp/vault/logical"
log "github.com/mgutz/logxi/v1"
Expand Down Expand Up @@ -164,6 +165,94 @@ func TestCore_EnableAudit_MixedFailures(t *testing.T) {
}
}

// Test that the local table actually gets populated as expected with local
// entries, and that upon reading the entries from both are recombined
// correctly
func TestCore_EnableAudit_Local(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
return &NoopAudit{
Config: config,
}, nil
}

c.auditBackends["fail"] = func(config *audit.BackendConfig) (audit.Backend, error) {
return nil, fmt.Errorf("failing enabling")
}

c.audit = &MountTable{
Type: auditTableType,
Entries: []*MountEntry{
&MountEntry{
Table: auditTableType,
Path: "noop/",
Type: "noop",
UUID: "abcd",
},
&MountEntry{
Table: auditTableType,
Path: "noop2/",
Type: "noop",
UUID: "bcde",
},
},
}

// Both should set up successfully
err := c.setupAudits()
if err != nil {
t.Fatal(err)
}

rawLocal, err := c.barrier.Get(coreLocalAuditConfigPath)
if err != nil {
t.Fatal(err)
}
if rawLocal == nil {
t.Fatal("expected non-nil local audit")
}
localAuditTable := &MountTable{}
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil {
t.Fatal(err)
}
if len(localAuditTable.Entries) > 0 {
t.Fatalf("expected no entries in local audit table, got %#v", localAuditTable)
}

c.audit.Entries[1].Local = true
if err := c.persistAudit(c.audit); err != nil {
t.Fatal(err)
}

rawLocal, err = c.barrier.Get(coreLocalAuditConfigPath)
if err != nil {
t.Fatal(err)
}
if rawLocal == nil {
t.Fatal("expected non-nil local audit")
}
localAuditTable = &MountTable{}
if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil {
t.Fatal(err)
}
if len(localAuditTable.Entries) != 1 {
t.Fatalf("expected one entry in local audit table, got %#v", localAuditTable)
}

oldAudit := c.audit
if err := c.loadAudits(); err != nil {
t.Fatal(err)
}

if !reflect.DeepEqual(oldAudit, c.audit) {
t.Fatalf("expected\n%#v\ngot\n%#v\n", oldAudit, c.audit)
}

if len(c.audit.Entries) != 2 {
t.Fatalf("expected two audit entries, got %#v", localAuditTable)
}
}

func TestCore_DisableAudit(t *testing.T) {
c, keys, _ := TestCoreUnsealed(t)
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
Expand Down Expand Up @@ -217,7 +306,7 @@ func TestCore_DisableAudit(t *testing.T) {

// Verify matching mount tables
if !reflect.DeepEqual(c.audit, c2.audit) {
t.Fatalf("mismatch: %v %v", c.audit, c2.audit)
t.Fatalf("mismatch:\n%#v\n%#v", c.audit, c2.audit)
}
}

Expand Down
66 changes: 44 additions & 22 deletions vault/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ var (

// This can be one of a few key types so the different params may or may not be filled
type clusterKeyParams struct {
Type string `json:"type"`
Type string `json:"type" structs:"type" mapstructure:"type"`
X *big.Int `json:"x" structs:"x" mapstructure:"x"`
Y *big.Int `json:"y" structs:"y" mapstructure:"y"`
D *big.Int `json:"d" structs:"d" mapstructure:"d"`
Expand Down Expand Up @@ -339,45 +339,67 @@ func (c *Core) stopClusterListener() {
c.logger.Info("core/stopClusterListener: success")
}

// ClusterTLSConfig generates a TLS configuration based on the local cluster
// key and cert.
// ClusterTLSConfig generates a TLS configuration based on the local/replicated
// cluster key and cert.
func (c *Core) ClusterTLSConfig() (*tls.Config, error) {
cluster, err := c.Cluster()
if err != nil {
return nil, err
}
if cluster == nil {
return nil, fmt.Errorf("cluster information is nil")
return nil, fmt.Errorf("local cluster information is nil")
}

// Prevent data races with the TLS parameters
c.clusterParamsLock.Lock()
defer c.clusterParamsLock.Unlock()

if c.localClusterCert == nil || len(c.localClusterCert) == 0 {
return nil, fmt.Errorf("cluster certificate is nil")
}
forwarding := c.localClusterCert != nil && len(c.localClusterCert) > 0

parsedCert, err := x509.ParseCertificate(c.localClusterCert)
if err != nil {
return nil, fmt.Errorf("error parsing local cluster certificate: %v", err)
var parsedCert *x509.Certificate
if forwarding {
parsedCert, err = x509.ParseCertificate(c.localClusterCert)
if err != nil {
return nil, fmt.Errorf("error parsing local cluster certificate: %v", err)
}

// This is idempotent, so be sure it's been added
c.clusterCertPool.AddCert(parsedCert)
}

// This is idempotent, so be sure it's been added
c.clusterCertPool.AddCert(parsedCert)
nameLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
c.clusterParamsLock.RLock()
defer c.clusterParamsLock.RUnlock()

tlsConfig := &tls.Config{
Certificates: []tls.Certificate{
tls.Certificate{
if forwarding && clientHello.ServerName == parsedCert.Subject.CommonName {
return &tls.Certificate{
Certificate: [][]byte{c.localClusterCert},
PrivateKey: c.localClusterPrivateKey,
},
},
RootCAs: c.clusterCertPool,
ServerName: parsedCert.Subject.CommonName,
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: c.clusterCertPool,
MinVersion: tls.VersionTLS12,
}, nil
}

return nil, nil
}

var clientCertificates []tls.Certificate
if forwarding {
clientCertificates = append(clientCertificates, tls.Certificate{
Certificate: [][]byte{c.localClusterCert},
PrivateKey: c.localClusterPrivateKey,
})
}

tlsConfig := &tls.Config{
// We need this here for the client side
Certificates: clientCertificates,
RootCAs: c.clusterCertPool,
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: c.clusterCertPool,
GetCertificate: nameLookup,
MinVersion: tls.VersionTLS12,
}
if forwarding {
tlsConfig.ServerName = parsedCert.Subject.CommonName
}

return tlsConfig, nil
Expand Down
Loading

0 comments on commit 2901591

Please sign in to comment.