From 2901591511d1792591b42dd20efb07943cd78717 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 16 Feb 2017 20:13:19 -0500 Subject: [PATCH] More porting from rep (#2389) * More porting from rep * Address feedback --- vault/audit.go | 81 ++++++++++++++++++++--- vault/audit_test.go | 91 +++++++++++++++++++++++++- vault/cluster.go | 66 ++++++++++++------- vault/core.go | 139 +++++++++++++++++++++++++++++----------- vault/logical_system.go | 19 +++++- vault/mount.go | 115 +++++++++++++++++++++++++++------ vault/mount_test.go | 91 ++++++++++++++++++++++++++ vault/policy_store.go | 20 +++++- 8 files changed, 530 insertions(+), 92 deletions(-) diff --git a/vault/audit.go b/vault/audit.go index 3df4cd96b1a2..b8974963eb14 100644 --- a/vault/audit.go +++ b/vault/audit.go @@ -2,7 +2,6 @@ package vault import ( "crypto/sha256" - "encoding/json" "errors" "fmt" "strings" @@ -26,6 +25,10 @@ const ( // can only be viewed or modified after an unseal. coreAuditConfigPath = "core/audit" + // coreLocalAuditConfigPath is used to store audit information for local + // (non-replicated) mounts + coreLocalAuditConfigPath = "core/local-audit" + // auditBarrierPrefix is the prefix to the UUID used in the // barrier view for the audit backends. auditBarrierPrefix = "audit/" @@ -69,12 +72,15 @@ func (c *Core) enableAudit(entry *MountEntry) error { } // Generate a new UUID and view - entryUUID, err := uuid.GenerateUUID() - if err != nil { - return err + if entry.UUID == "" { + entryUUID, err := uuid.GenerateUUID() + if err != nil { + return err + } + entry.UUID = entryUUID } - entry.UUID = entryUUID - view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/") + viewPath := auditBarrierPrefix + entry.UUID + "/" + view := NewBarrierView(c.barrier, viewPath) // Lookup the new backend backend, err := c.newAuditBackend(entry, view, entry.Options) @@ -119,6 +125,12 @@ func (c *Core) disableAudit(path string) (bool, error) { c.removeAuditReloadFunc(entry) + // When unmounting all entries the JSON code will load back up from storage + // as a nil slice, which kills tests...just set it nil explicitly + if len(newTable.Entries) == 0 { + newTable.Entries = nil + } + // Update the audit table if err := c.persistAudit(newTable); err != nil { return true, errors.New("failed to update audit table") @@ -131,12 +143,14 @@ func (c *Core) disableAudit(path string) (bool, error) { if c.logger.IsInfo() { c.logger.Info("core: disabled audit backend", "path", path) } + return true, nil } // loadAudits is invoked as part of postUnseal to load the audit table func (c *Core) loadAudits() error { auditTable := &MountTable{} + localAuditTable := &MountTable{} // Load the existing audit table raw, err := c.barrier.Get(coreAuditConfigPath) @@ -144,6 +158,11 @@ func (c *Core) loadAudits() error { c.logger.Error("core: failed to read audit table", "error", err) return errLoadAuditFailed } + rawLocal, err := c.barrier.Get(coreLocalAuditConfigPath) + if err != nil { + c.logger.Error("core: failed to read local audit table", "error", err) + return errLoadAuditFailed + } c.auditLock.Lock() defer c.auditLock.Unlock() @@ -155,6 +174,13 @@ func (c *Core) loadAudits() error { } c.audit = auditTable } + if rawLocal != nil { + if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil { + c.logger.Error("core: failed to decode local audit table", "error", err) + return errLoadAuditFailed + } + c.audit.Entries = append(c.audit.Entries, localAuditTable.Entries...) + } // Done if we have restored the audit table if c.audit != nil { @@ -203,17 +229,33 @@ func (c *Core) persistAudit(table *MountTable) error { } } + nonLocalAudit := &MountTable{ + Type: auditTableType, + } + + localAudit := &MountTable{ + Type: auditTableType, + } + + for _, entry := range table.Entries { + if entry.Local { + localAudit.Entries = append(localAudit.Entries, entry) + } else { + nonLocalAudit.Entries = append(nonLocalAudit.Entries, entry) + } + } + // Marshal the table - raw, err := json.Marshal(table) + compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalAudit, nil) if err != nil { - c.logger.Error("core: failed to encode audit table", "error", err) + c.logger.Error("core: failed to encode and/or compress audit table", "error", err) return err } // Create an entry entry := &Entry{ Key: coreAuditConfigPath, - Value: raw, + Value: compressedBytes, } // Write to the physical backend @@ -221,6 +263,24 @@ func (c *Core) persistAudit(table *MountTable) error { c.logger.Error("core: failed to persist audit table", "error", err) return err } + + // Repeat with local audit + compressedBytes, err = jsonutil.EncodeJSONAndCompress(localAudit, nil) + if err != nil { + c.logger.Error("core: failed to encode and/or compress local audit table", "error", err) + return err + } + + entry = &Entry{ + Key: coreLocalAuditConfigPath, + Value: compressedBytes, + } + + if err := c.barrier.Put(entry); err != nil { + c.logger.Error("core: failed to persist local audit table", "error", err) + return err + } + return nil } @@ -236,7 +296,8 @@ func (c *Core) setupAudits() error { for _, entry := range c.audit.Entries { // Create a barrier view using the UUID - view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/") + viewPath := auditBarrierPrefix + entry.UUID + "/" + view := NewBarrierView(c.barrier, viewPath) // Initialize the backend audit, err := c.newAuditBackend(entry, view, entry.Options) diff --git a/vault/audit_test.go b/vault/audit_test.go index e1cd51cf9962..491be4915876 100644 --- a/vault/audit_test.go +++ b/vault/audit_test.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/errwrap" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/audit" + "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/logformat" "github.com/hashicorp/vault/logical" log "github.com/mgutz/logxi/v1" @@ -164,6 +165,94 @@ func TestCore_EnableAudit_MixedFailures(t *testing.T) { } } +// Test that the local table actually gets populated as expected with local +// entries, and that upon reading the entries from both are recombined +// correctly +func TestCore_EnableAudit_Local(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) { + return &NoopAudit{ + Config: config, + }, nil + } + + c.auditBackends["fail"] = func(config *audit.BackendConfig) (audit.Backend, error) { + return nil, fmt.Errorf("failing enabling") + } + + c.audit = &MountTable{ + Type: auditTableType, + Entries: []*MountEntry{ + &MountEntry{ + Table: auditTableType, + Path: "noop/", + Type: "noop", + UUID: "abcd", + }, + &MountEntry{ + Table: auditTableType, + Path: "noop2/", + Type: "noop", + UUID: "bcde", + }, + }, + } + + // Both should set up successfully + err := c.setupAudits() + if err != nil { + t.Fatal(err) + } + + rawLocal, err := c.barrier.Get(coreLocalAuditConfigPath) + if err != nil { + t.Fatal(err) + } + if rawLocal == nil { + t.Fatal("expected non-nil local audit") + } + localAuditTable := &MountTable{} + if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil { + t.Fatal(err) + } + if len(localAuditTable.Entries) > 0 { + t.Fatalf("expected no entries in local audit table, got %#v", localAuditTable) + } + + c.audit.Entries[1].Local = true + if err := c.persistAudit(c.audit); err != nil { + t.Fatal(err) + } + + rawLocal, err = c.barrier.Get(coreLocalAuditConfigPath) + if err != nil { + t.Fatal(err) + } + if rawLocal == nil { + t.Fatal("expected non-nil local audit") + } + localAuditTable = &MountTable{} + if err := jsonutil.DecodeJSON(rawLocal.Value, localAuditTable); err != nil { + t.Fatal(err) + } + if len(localAuditTable.Entries) != 1 { + t.Fatalf("expected one entry in local audit table, got %#v", localAuditTable) + } + + oldAudit := c.audit + if err := c.loadAudits(); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(oldAudit, c.audit) { + t.Fatalf("expected\n%#v\ngot\n%#v\n", oldAudit, c.audit) + } + + if len(c.audit.Entries) != 2 { + t.Fatalf("expected two audit entries, got %#v", localAuditTable) + } +} + func TestCore_DisableAudit(t *testing.T) { c, keys, _ := TestCoreUnsealed(t) c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) { @@ -217,7 +306,7 @@ func TestCore_DisableAudit(t *testing.T) { // Verify matching mount tables if !reflect.DeepEqual(c.audit, c2.audit) { - t.Fatalf("mismatch: %v %v", c.audit, c2.audit) + t.Fatalf("mismatch:\n%#v\n%#v", c.audit, c2.audit) } } diff --git a/vault/cluster.go b/vault/cluster.go index 732080759ba7..1686c09cedc6 100644 --- a/vault/cluster.go +++ b/vault/cluster.go @@ -43,7 +43,7 @@ var ( // This can be one of a few key types so the different params may or may not be filled type clusterKeyParams struct { - Type string `json:"type"` + Type string `json:"type" structs:"type" mapstructure:"type"` X *big.Int `json:"x" structs:"x" mapstructure:"x"` Y *big.Int `json:"y" structs:"y" mapstructure:"y"` D *big.Int `json:"d" structs:"d" mapstructure:"d"` @@ -339,45 +339,67 @@ func (c *Core) stopClusterListener() { c.logger.Info("core/stopClusterListener: success") } -// ClusterTLSConfig generates a TLS configuration based on the local cluster -// key and cert. +// ClusterTLSConfig generates a TLS configuration based on the local/replicated +// cluster key and cert. func (c *Core) ClusterTLSConfig() (*tls.Config, error) { cluster, err := c.Cluster() if err != nil { return nil, err } if cluster == nil { - return nil, fmt.Errorf("cluster information is nil") + return nil, fmt.Errorf("local cluster information is nil") } // Prevent data races with the TLS parameters c.clusterParamsLock.Lock() defer c.clusterParamsLock.Unlock() - if c.localClusterCert == nil || len(c.localClusterCert) == 0 { - return nil, fmt.Errorf("cluster certificate is nil") - } + forwarding := c.localClusterCert != nil && len(c.localClusterCert) > 0 - parsedCert, err := x509.ParseCertificate(c.localClusterCert) - if err != nil { - return nil, fmt.Errorf("error parsing local cluster certificate: %v", err) + var parsedCert *x509.Certificate + if forwarding { + parsedCert, err = x509.ParseCertificate(c.localClusterCert) + if err != nil { + return nil, fmt.Errorf("error parsing local cluster certificate: %v", err) + } + + // This is idempotent, so be sure it's been added + c.clusterCertPool.AddCert(parsedCert) } - // This is idempotent, so be sure it's been added - c.clusterCertPool.AddCert(parsedCert) + nameLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + c.clusterParamsLock.RLock() + defer c.clusterParamsLock.RUnlock() - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{ - tls.Certificate{ + if forwarding && clientHello.ServerName == parsedCert.Subject.CommonName { + return &tls.Certificate{ Certificate: [][]byte{c.localClusterCert}, PrivateKey: c.localClusterPrivateKey, - }, - }, - RootCAs: c.clusterCertPool, - ServerName: parsedCert.Subject.CommonName, - ClientAuth: tls.RequireAndVerifyClientCert, - ClientCAs: c.clusterCertPool, - MinVersion: tls.VersionTLS12, + }, nil + } + + return nil, nil + } + + var clientCertificates []tls.Certificate + if forwarding { + clientCertificates = append(clientCertificates, tls.Certificate{ + Certificate: [][]byte{c.localClusterCert}, + PrivateKey: c.localClusterPrivateKey, + }) + } + + tlsConfig := &tls.Config{ + // We need this here for the client side + Certificates: clientCertificates, + RootCAs: c.clusterCertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: c.clusterCertPool, + GetCertificate: nameLookup, + MinVersion: tls.VersionTLS12, + } + if forwarding { + tlsConfig.ServerName = parsedCert.Subject.CommonName } return tlsConfig, nil diff --git a/vault/core.go b/vault/core.go index f3b9bf696c1f..02a1c010c874 100644 --- a/vault/core.go +++ b/vault/core.go @@ -1,9 +1,9 @@ package vault import ( - "bytes" "crypto" "crypto/ecdsa" + "crypto/subtle" "crypto/x509" "errors" "fmt" @@ -57,6 +57,11 @@ const ( // leaderPrefixCleanDelay is how long to wait between deletions // of orphaned leader keys, to prevent slamming the backend. leaderPrefixCleanDelay = 200 * time.Millisecond + + // coreKeyringCanaryPath is used as a canary to indicate to replicated + // clusters that they need to perform a rekey operation synchronously; this + // isn't keyring-canary to avoid ignoring it when ignoring core/keyring + coreKeyringCanaryPath = "core/canary-keyring" ) var ( @@ -80,6 +85,12 @@ var ( // step down of the active node, to prevent instantly regrabbing the lock. // It's var not const so that tests can manipulate it. manualStepDownSleepPeriod = 10 * time.Second + + // Functions only in the Enterprise version + enterprisePostUnseal = enterprisePostUnsealImpl + enterprisePreSeal = enterprisePreSealImpl + startReplication = startReplicationImpl + stopReplication = stopReplicationImpl ) // ReloadFunc are functions that are called when a reload is requested. @@ -126,6 +137,11 @@ type unlockInformation struct { // interface for API handlers and is responsible for managing the logical and physical // backends, router, security barrier, and audit trails. type Core struct { + // N.B.: This is used to populate a dev token down replication, as + // otherwise, after replication is started, a dev would have to go through + // the generate-root process simply to talk to the new follower cluster. + devToken string + // HABackend may be available depending on the physical backend ha physical.HABackend @@ -261,7 +277,7 @@ type Core struct { // // Name clusterName string - // Used to modify cluster TLS params + // Used to modify cluster parameters clusterParamsLock sync.RWMutex // The private key stored in the barrier used for establishing // mutually-authenticated connections between Vault cluster members @@ -308,6 +324,8 @@ type Core struct { // CoreConfig is used to parameterize a core type CoreConfig struct { + DevToken string `json:"dev_token" structs:"dev_token" mapstructure:"dev_token"` + LogicalBackends map[string]logical.Factory `json:"logical_backends" structs:"logical_backends" mapstructure:"logical_backends"` CredentialBackends map[string]logical.Factory `json:"credential_backends" structs:"credential_backends" mapstructure:"credential_backends"` @@ -383,35 +401,12 @@ func NewCore(conf *CoreConfig) (*Core, error) { conf.Logger = logformat.NewVaultLogger(log.LevelTrace) } - if !conf.DisableMlock { - // Ensure our memory usage is locked into physical RAM - if err := mlock.LockMemory(); err != nil { - return nil, fmt.Errorf( - "Failed to lock memory: %v\n\n"+ - "This usually means that the mlock syscall is not available.\n"+ - "Vault uses mlock to prevent memory from being swapped to\n"+ - "disk. This requires root privileges as well as a machine\n"+ - "that supports mlock. Please enable mlock on your system or\n"+ - "disable Vault from using it. To disable Vault from using it,\n"+ - "set the `disable_mlock` configuration option in your configuration\n"+ - "file.", - err) - } - } - - // Construct a new AES-GCM barrier - barrier, err := NewAESGCMBarrier(conf.Physical) - if err != nil { - return nil, fmt.Errorf("barrier setup failed: %v", err) - } - // Setup the core c := &Core{ redirectAddr: conf.RedirectAddr, clusterAddr: conf.ClusterAddr, physical: conf.Physical, seal: conf.Seal, - barrier: barrier, router: NewRouter(), sealed: true, standby: true, @@ -425,11 +420,34 @@ func NewCore(conf *CoreConfig) (*Core, error) { clusterListenerShutdownSuccessCh: make(chan struct{}), } - // Wrap the backend in a cache unless disabled + // Wrap the physical backend in a cache layer if enabled and not already wrapped if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache { c.physical = physical.NewCache(conf.Physical, conf.CacheSize, conf.Logger) } + if !conf.DisableMlock { + // Ensure our memory usage is locked into physical RAM + if err := mlock.LockMemory(); err != nil { + return nil, fmt.Errorf( + "Failed to lock memory: %v\n\n"+ + "This usually means that the mlock syscall is not available.\n"+ + "Vault uses mlock to prevent memory from being swapped to\n"+ + "disk. This requires root privileges as well as a machine\n"+ + "that supports mlock. Please enable mlock on your system or\n"+ + "disable Vault from using it. To disable Vault from using it,\n"+ + "set the `disable_mlock` configuration option in your configuration\n"+ + "file.", + err) + } + } + + // Construct a new AES-GCM barrier + var err error + c.barrier, err = NewAESGCMBarrier(c.physical) + if err != nil { + return nil, fmt.Errorf("barrier setup failed: %v", err) + } + if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() { c.ha = conf.HAPhysical } @@ -796,17 +814,29 @@ func (c *Core) Unseal(key []byte) (bool, error) { return true, nil } + masterKey, err := c.unsealPart(config, key) + if err != nil { + return false, err + } + if masterKey != nil { + return c.unsealInternal(masterKey) + } + + return false, nil +} + +func (c *Core) unsealPart(config *SealConfig, key []byte) ([]byte, error) { // Check if we already have this piece if c.unlockInfo != nil { for _, existing := range c.unlockInfo.Parts { - if bytes.Equal(existing, key) { - return false, nil + if subtle.ConstantTimeCompare(existing, key) == 1 { + return nil, nil } } } else { uuid, err := uuid.GenerateUUID() if err != nil { - return false, err + return nil, err } c.unlockInfo = &unlockInformation{ Nonce: uuid, @@ -821,27 +851,37 @@ func (c *Core) Unseal(key []byte) (bool, error) { if c.logger.IsDebug() { c.logger.Debug("core: cannot unseal, not enough keys", "keys", len(c.unlockInfo.Parts), "threshold", config.SecretThreshold, "nonce", c.unlockInfo.Nonce) } - return false, nil + return nil, nil } + // Best-effort memzero of unlock parts once we're done with them + defer func() { + for i, _ := range c.unlockInfo.Parts { + memzero(c.unlockInfo.Parts[i]) + } + c.unlockInfo = nil + }() + // Recover the master key var masterKey []byte + var err error if config.SecretThreshold == 1 { - masterKey = c.unlockInfo.Parts[0] - c.unlockInfo = nil + masterKey = make([]byte, len(c.unlockInfo.Parts[0])) + copy(masterKey, c.unlockInfo.Parts[0]) } else { masterKey, err = shamir.Combine(c.unlockInfo.Parts) - c.unlockInfo = nil if err != nil { - return false, fmt.Errorf("failed to compute master key: %v", err) + return nil, fmt.Errorf("failed to compute master key: %v", err) } } - defer memzero(masterKey) - return c.unsealInternal(masterKey) + return masterKey, nil } +// This must be called with the state write lock held func (c *Core) unsealInternal(masterKey []byte) (bool, error) { + defer memzero(masterKey) + // Attempt to unlock if err := c.barrier.Unseal(masterKey); err != nil { return false, err @@ -860,12 +900,14 @@ func (c *Core) unsealInternal(masterKey []byte) (bool, error) { c.logger.Warn("core: vault is sealed") return false, err } + if err := c.postUnseal(); err != nil { c.logger.Error("core: post-unseal setup failed", "error", err) c.barrier.Seal() c.logger.Warn("core: vault is sealed") return false, err } + c.standby = false } else { // Go to standby mode, wait until we are active to unseal @@ -1161,6 +1203,7 @@ func (c *Core) postUnseal() (retErr error) { if purgable, ok := c.physical.(physical.Purgable); ok { purgable.Purge() } + // HA mode requires us to handle keyring rotation and rekeying if c.ha != nil { // We want to reload these from disk so that in case of a rekey we're @@ -1183,6 +1226,9 @@ func (c *Core) postUnseal() (retErr error) { return err } } + if err := enterprisePostUnseal(c); err != nil { + return err + } if err := c.ensureWrappingKey(); err != nil { return err } @@ -1244,6 +1290,7 @@ func (c *Core) preSeal() error { c.metricsCh = nil } var result error + if c.ha != nil { c.stopClusterListener() } @@ -1266,6 +1313,10 @@ func (c *Core) preSeal() error { if err := c.unloadMounts(); err != nil { result = multierror.Append(result, errwrap.Wrapf("error unloading mounts: {{err}}", err)) } + if err := enterprisePreSeal(c); err != nil { + result = multierror.Append(result, err) + } + // Purge the backend if supported if purgable, ok := c.physical.(physical.Purgable); ok { purgable.Purge() @@ -1274,6 +1325,22 @@ func (c *Core) preSeal() error { return result } +func enterprisePostUnsealImpl(c *Core) error { + return nil +} + +func enterprisePreSealImpl(c *Core) error { + return nil +} + +func startReplicationImpl(c *Core) error { + return nil +} + +func stopReplicationImpl(c *Core) error { + return nil +} + // runStandby is a long running routine that is used when an HA backend // is enabled. It waits until we are leader and switches this Vault to // active. diff --git a/vault/logical_system.go b/vault/logical_system.go index e0f1163214ae..b3756bcc948a 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -22,6 +22,23 @@ var ( protectedPaths = []string{ "core", } + + replicationPaths = []*framework.Path{ + &framework.Path{ + Pattern: "replication/status", + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + var state consts.ReplicationState + resp := &logical.Response{ + Data: map[string]interface{}{ + "mode": state.String(), + }, + } + return resp, nil + }, + }, + }, + } ) func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backend, error) { @@ -675,7 +692,7 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen }, } - b.Backend.Paths = append(b.Backend.Paths, b.replicationPaths()...) + b.Backend.Paths = append(b.Backend.Paths, replicationPaths...) b.Backend.Invalidate = b.invalidate diff --git a/vault/mount.go b/vault/mount.go index 072583ae9168..a7e29a2a935d 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -19,6 +19,10 @@ const ( // can only be viewed or modified after an unseal. coreMountConfigPath = "core/mounts" + // coreLocalMountConfigPath is used to store mount configuration for local + // (non-replicated) mounts + coreLocalMountConfigPath = "core/local-mounts" + // backendBarrierPrefix is the prefix to the UUID used in the // barrier view for the backends. backendBarrierPrefix = "logical/" @@ -124,6 +128,7 @@ type MountEntry struct { UUID string `json:"uuid"` // Barrier view UUID Config MountConfig `json:"config"` // Configuration related to this mount (but not backend-derived) Options map[string]string `json:"options"` // Backend options + Local bool `json:"local"` // Local mounts are not replicated or affected by replication Tainted bool `json:"tainted,omitempty"` // Set as a Write-Ahead flag for unmount/remount } @@ -147,28 +152,29 @@ func (e *MountEntry) Clone() *MountEntry { UUID: e.UUID, Config: e.Config, Options: optClone, + Local: e.Local, Tainted: e.Tainted, } } // Mount is used to mount a new backend to the mount table. -func (c *Core) mount(me *MountEntry) error { +func (c *Core) mount(entry *MountEntry) error { // Ensure we end the path in a slash - if !strings.HasSuffix(me.Path, "/") { - me.Path += "/" + if !strings.HasSuffix(entry.Path, "/") { + entry.Path += "/" } // Prevent protected paths from being mounted for _, p := range protectedMounts { - if strings.HasPrefix(me.Path, p) { - return logical.CodedError(403, fmt.Sprintf("cannot mount '%s'", me.Path)) + if strings.HasPrefix(entry.Path, p) { + return logical.CodedError(403, fmt.Sprintf("cannot mount '%s'", entry.Path)) } } // Do not allow more than one instance of a singleton mount for _, p := range singletonMounts { - if me.Type == p { - return logical.CodedError(403, fmt.Sprintf("Cannot mount more than one instance of '%s'", me.Type)) + if entry.Type == p { + return logical.CodedError(403, fmt.Sprintf("Cannot mount more than one instance of '%s'", entry.Type)) } } @@ -176,37 +182,47 @@ func (c *Core) mount(me *MountEntry) error { defer c.mountsLock.Unlock() // Verify there is no conflicting mount - if match := c.router.MatchingMount(me.Path); match != "" { + if match := c.router.MatchingMount(entry.Path); match != "" { return logical.CodedError(409, fmt.Sprintf("existing mount at %s", match)) } // Generate a new UUID and view - meUUID, err := uuid.GenerateUUID() + if entry.UUID == "" { + entryUUID, err := uuid.GenerateUUID() + if err != nil { + return err + } + entry.UUID = entryUUID + } + viewPath := backendBarrierPrefix + entry.UUID + "/" + view := NewBarrierView(c.barrier, viewPath) + sysView := c.mountEntrySysView(entry) + + backend, err := c.newLogicalBackend(entry.Type, sysView, view, nil) if err != nil { return err } - me.UUID = meUUID - view := NewBarrierView(c.barrier, backendBarrierPrefix+me.UUID+"/") - backend, err := c.newLogicalBackend(me.Type, c.mountEntrySysView(me), view, nil) - if err != nil { + // Call initialize; this takes care of init tasks that must be run after + // the ignore paths are collected + if err := backend.Initialize(); err != nil { return err } newTable := c.mounts.shallowClone() - newTable.Entries = append(newTable.Entries, me) + newTable.Entries = append(newTable.Entries, entry) if err := c.persistMounts(newTable); err != nil { c.logger.Error("core: failed to update mount table", "error", err) return logical.CodedError(500, "failed to update mount table") } c.mounts = newTable - if err := c.router.Mount(backend, me.Path, me, view); err != nil { + if err := c.router.Mount(backend, entry.Path, entry, view); err != nil { return err } if c.logger.IsInfo() { - c.logger.Info("core: successful mount", "path", me.Path, "type", me.Type) + c.logger.Info("core: successful mount", "path", entry.Path, "type", entry.Type) } return nil } @@ -291,6 +307,12 @@ func (c *Core) removeMountEntry(path string) error { newTable := c.mounts.shallowClone() newTable.remove(path) + // When unmounting all entries the JSON code will load back up from storage + // as a nil slice, which kills tests...just set it nil explicitly + if len(newTable.Entries) == 0 { + newTable.Entries = nil + } + // Update the mount table if err := c.persistMounts(newTable); err != nil { c.logger.Error("core: failed to update mount table", "error", err) @@ -405,12 +427,18 @@ func (c *Core) remount(src, dst string) error { // loadMounts is invoked as part of postUnseal to load the mount table func (c *Core) loadMounts() error { mountTable := &MountTable{} + localMountTable := &MountTable{} // Load the existing mount table raw, err := c.barrier.Get(coreMountConfigPath) if err != nil { c.logger.Error("core: failed to read mount table", "error", err) return errLoadMountsFailed } + rawLocal, err := c.barrier.Get(coreLocalMountConfigPath) + if err != nil { + c.logger.Error("core: failed to read local mount table", "error", err) + return errLoadMountsFailed + } c.mountsLock.Lock() defer c.mountsLock.Unlock() @@ -425,6 +453,13 @@ func (c *Core) loadMounts() error { } c.mounts = mountTable } + if rawLocal != nil { + if err := jsonutil.DecodeJSON(rawLocal.Value, localMountTable); err != nil { + c.logger.Error("core: failed to decompress and/or decode the local mount table", "error", err) + return err + } + c.mounts.Entries = append(c.mounts.Entries, localMountTable.Entries...) + } // Ensure that required entries are loaded, or new ones // added may never get loaded at all. Note that this @@ -492,8 +527,24 @@ func (c *Core) persistMounts(table *MountTable) error { } } + nonLocalMounts := &MountTable{ + Type: mountTableType, + } + + localMounts := &MountTable{ + Type: mountTableType, + } + + for _, entry := range table.Entries { + if entry.Local { + localMounts.Entries = append(localMounts.Entries, entry) + } else { + nonLocalMounts.Entries = append(nonLocalMounts.Entries, entry) + } + } + // Encode the mount table into JSON and compress it (lzw). - compressedBytes, err := jsonutil.EncodeJSONAndCompress(table, nil) + compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalMounts, nil) if err != nil { c.logger.Error("core: failed to encode and/or compress the mount table", "error", err) return err @@ -510,6 +561,24 @@ func (c *Core) persistMounts(table *MountTable) error { c.logger.Error("core: failed to persist mount table", "error", err) return err } + + // Repeat with local mounts + compressedBytes, err = jsonutil.EncodeJSONAndCompress(localMounts, nil) + if err != nil { + c.logger.Error("core: failed to encode and/or compress the local mount table", "error", err) + return err + } + + entry = &Entry{ + Key: coreLocalMountConfigPath, + Value: compressedBytes, + } + + if err := c.barrier.Put(entry); err != nil { + c.logger.Error("core: failed to persist local mount table", "error", err) + return err + } + return nil } @@ -532,15 +601,19 @@ func (c *Core) setupMounts() error { // Create a barrier view using the UUID view = NewBarrierView(c.barrier, barrierPath) - + sysView := c.mountEntrySysView(entry) // Initialize the backend // Create the new backend - backend, err = c.newLogicalBackend(entry.Type, c.mountEntrySysView(entry), view, nil) + backend, err = c.newLogicalBackend(entry.Type, sysView, view, nil) if err != nil { c.logger.Error("core: failed to create mount entry", "path", entry.Path, "error", err) return errLoadMountsFailed } + if err := backend.Initialize(); err != nil { + return err + } + switch entry.Type { case "system": c.systemBarrierView = view @@ -616,10 +689,10 @@ func (c *Core) newLogicalBackend(t string, sysView logical.SystemView, view logi // mountEntrySysView creates a logical.SystemView from global and // mount-specific entries; because this should be called when setting // up a mountEntry, it doesn't check to ensure that me is not nil -func (c *Core) mountEntrySysView(me *MountEntry) logical.SystemView { +func (c *Core) mountEntrySysView(entry *MountEntry) logical.SystemView { return dynamicSystemView{ core: c, - mountEntry: me, + mountEntry: entry, } } diff --git a/vault/mount_test.go b/vault/mount_test.go index dd6ef59445cc..a00d37945463 100644 --- a/vault/mount_test.go +++ b/vault/mount_test.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/helper/compressutil" + "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/logical" ) @@ -82,6 +83,96 @@ func TestCore_Mount(t *testing.T) { } } +// Test that the local table actually gets populated as expected with local +// entries, and that upon reading the entries from both are recombined +// correctly +func TestCore_Mount_Local(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + + c.mounts = &MountTable{ + Type: mountTableType, + Entries: []*MountEntry{ + &MountEntry{ + Table: mountTableType, + Path: "noop/", + Type: "generic", + UUID: "abcd", + }, + &MountEntry{ + Table: mountTableType, + Path: "noop2/", + Type: "generic", + UUID: "bcde", + }, + }, + } + + // Both should set up successfully + err := c.setupMounts() + if err != nil { + t.Fatal(err) + } + if len(c.mounts.Entries) != 2 { + t.Fatalf("expected two entries, got %d", len(c.mounts.Entries)) + } + + rawLocal, err := c.barrier.Get(coreLocalMountConfigPath) + if err != nil { + t.Fatal(err) + } + if rawLocal == nil { + t.Fatal("expected non-nil local mounts") + } + localMountsTable := &MountTable{} + if err := jsonutil.DecodeJSON(rawLocal.Value, localMountsTable); err != nil { + t.Fatal(err) + } + if len(localMountsTable.Entries) > 0 { + t.Fatalf("expected no entries in local mount table, got %#v", localMountsTable) + } + + c.mounts.Entries[1].Local = true + if err := c.persistMounts(c.mounts); err != nil { + t.Fatal(err) + } + + rawLocal, err = c.barrier.Get(coreLocalMountConfigPath) + if err != nil { + t.Fatal(err) + } + if rawLocal == nil { + t.Fatal("expected non-nil local mount") + } + localMountsTable = &MountTable{} + if err := jsonutil.DecodeJSON(rawLocal.Value, localMountsTable); err != nil { + t.Fatal(err) + } + if len(localMountsTable.Entries) != 1 { + t.Fatalf("expected one entry in local mount table, got %#v", localMountsTable) + } + + oldMounts := c.mounts + if err := c.loadMounts(); err != nil { + t.Fatal(err) + } + compEntries := c.mounts.Entries[:0] + // Filter out required mounts + for _, v := range c.mounts.Entries { + if v.Type == "generic" { + compEntries = append(compEntries, v) + } + } + c.mounts.Entries = compEntries + + if !reflect.DeepEqual(oldMounts, c.mounts) { + t.Fatalf("expected\n%#v\ngot\n%#v\n", oldMounts, c.mounts) + } + + if len(c.mounts.Entries) != 2 { + t.Fatalf("expected two mount entries, got %#v", localMountsTable) + } +} + func TestCore_Unmount(t *testing.T) { c, keys, _ := TestCoreUnsealed(t) existed, err := c.unmount("secret") diff --git a/vault/policy_store.go b/vault/policy_store.go index 873a8dde8bbd..8200f22cdd50 100644 --- a/vault/policy_store.go +++ b/vault/policy_store.go @@ -2,11 +2,13 @@ package vault import ( "fmt" + "strings" "time" "github.com/armon/go-metrics" "github.com/hashicorp/errwrap" "github.com/hashicorp/golang-lru" + "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" ) @@ -137,7 +139,13 @@ func (c *Core) setupPolicyStore() error { view := c.systemBarrierView.SubView(policySubPath) // Create the policy store - c.policyStore = NewPolicyStore(view, &dynamicSystemView{core: c}) + sysView := &dynamicSystemView{core: c} + c.policyStore = NewPolicyStore(view, sysView) + + if sysView.ReplicationState() == consts.ReplicationSecondary { + // Policies will sync from the primary + return nil + } // Ensure that the default policy exists, and if not, create it policy, err := c.policyStore.GetPolicy("default") @@ -173,6 +181,16 @@ func (c *Core) teardownPolicyStore() error { return nil } +func (ps *PolicyStore) invalidate(name string) { + if ps.lru == nil { + // Nothing to do if the cache is not used + return + } + + // This may come with a prefixed "/" due to joining the file path + ps.lru.Remove(strings.TrimPrefix(name, "/")) +} + // SetPolicy is used to create or update the given policy func (ps *PolicyStore) SetPolicy(p *Policy) error { defer metrics.MeasureSince([]string{"policy", "set_policy"}, time.Now())