From 38fc7576737a38232760ebd0a0ee5b1b221ab275 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 16 Feb 2017 22:37:37 -0500 Subject: [PATCH 1/2] More rep porting --- vault/auth.go | 92 ++++++++++++++++++++++++++++++----- vault/auth_test.go | 86 +++++++++++++++++++++++++++++++- vault/barrier.go | 19 ++++++++ vault/barrier_aes_gcm.go | 91 ++++++++++++++++++++++++++++++---- vault/barrier_aes_gcm_test.go | 27 ++++++++++ vault/barrier_view.go | 25 +++++++--- vault/capabilities.go | 22 +++------ vault/init.go | 51 ++++++++++--------- vault/rekey_test.go | 4 +- vault/request_handling.go | 10 ++-- vault/testing.go | 6 ++- vault/token_store.go | 41 +++++++++++----- vault/token_store_test.go | 6 +++ 13 files changed, 388 insertions(+), 92 deletions(-) diff --git a/vault/auth.go b/vault/auth.go index 399626ca1fbe..5c8315645f5b 100644 --- a/vault/auth.go +++ b/vault/auth.go @@ -1,7 +1,6 @@ package vault import ( - "encoding/json" "errors" "fmt" "strings" @@ -17,6 +16,10 @@ const ( // can only be viewed or modified after an unseal. coreAuthConfigPath = "core/auth" + // coreLocalAuthConfigPath is used to store credential configuration for + // local (non-replicated) mounts + coreLocalAuthConfigPath = "core/local-auth" + // credentialBarrierPrefix is the prefix to the UUID used in the // barrier view for the credential backends. credentialBarrierPrefix = "auth/" @@ -71,19 +74,28 @@ func (c *Core) enableCredential(entry *MountEntry) error { } // Generate a new UUID and view - entryUUID, err := uuid.GenerateUUID() - if err != nil { - return err + if entry.UUID == "" { + entryUUID, err := uuid.GenerateUUID() + if err != nil { + return err + } + entry.UUID = entryUUID } - entry.UUID = entryUUID - view := NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/") + + viewPath := credentialBarrierPrefix + entry.UUID + "/" + view := NewBarrierView(c.barrier, viewPath) + sysView := c.mountEntrySysView(entry) // Create the new backend - backend, err := c.newCredentialBackend(entry.Type, c.mountEntrySysView(entry), view, nil) + backend, err := c.newCredentialBackend(entry.Type, sysView, view, nil) if err != nil { return err } + if err := backend.Initialize(); err != nil { + return err + } + // Update the auth table newTable := c.auth.shallowClone() newTable.Entries = append(newTable.Entries, entry) @@ -121,7 +133,7 @@ func (c *Core) disableCredential(path string) (bool, error) { fullPath := credentialRoutePrefix + path view := c.router.MatchingStorageView(fullPath) if view == nil { - return false, fmt.Errorf("no matching backend") + return false, fmt.Errorf("no matching backend %s", fullPath) } // Mark the entry as tainted @@ -206,12 +218,19 @@ func (c *Core) taintCredEntry(path string) error { // loadCredentials is invoked as part of postUnseal to load the auth table func (c *Core) loadCredentials() error { authTable := &MountTable{} + localAuthTable := &MountTable{} + // Load the existing mount table raw, err := c.barrier.Get(coreAuthConfigPath) if err != nil { c.logger.Error("core: failed to read auth table", "error", err) return errLoadAuthFailed } + rawLocal, err := c.barrier.Get(coreLocalAuthConfigPath) + if err != nil { + c.logger.Error("core: failed to read local auth table", "error", err) + return errLoadAuthFailed + } c.authLock.Lock() defer c.authLock.Unlock() @@ -223,6 +242,13 @@ func (c *Core) loadCredentials() error { } c.auth = authTable } + if rawLocal != nil { + if err := jsonutil.DecodeJSON(rawLocal.Value, localAuthTable); err != nil { + c.logger.Error("core: failed to decode local auth table", "error", err) + return errLoadAuthFailed + } + c.auth.Entries = append(c.auth.Entries, localAuthTable.Entries...) + } // Done if we have restored the auth table if c.auth != nil { @@ -272,17 +298,33 @@ func (c *Core) persistAuth(table *MountTable) error { } } + nonLocalAuth := &MountTable{ + Type: credentialTableType, + } + + localAuth := &MountTable{ + Type: credentialTableType, + } + + for _, entry := range table.Entries { + if entry.Local { + localAuth.Entries = append(localAuth.Entries, entry) + } else { + nonLocalAuth.Entries = append(nonLocalAuth.Entries, entry) + } + } + // Marshal the table - raw, err := json.Marshal(table) + compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalAuth, nil) if err != nil { - c.logger.Error("core: failed to encode auth table", "error", err) + c.logger.Error("core: failed to encode and/or compress auth table", "error", err) return err } // Create an entry entry := &Entry{ Key: coreAuthConfigPath, - Value: raw, + Value: compressedBytes, } // Write to the physical backend @@ -290,6 +332,24 @@ func (c *Core) persistAuth(table *MountTable) error { c.logger.Error("core: failed to persist auth table", "error", err) return err } + + // Repeat with local auth + compressedBytes, err = jsonutil.EncodeJSONAndCompress(localAuth, nil) + if err != nil { + c.logger.Error("core: failed to encode and/or compress local auth table", "error", err) + return err + } + + entry = &Entry{ + Key: coreLocalAuthConfigPath, + Value: compressedBytes, + } + + if err := c.barrier.Put(entry); err != nil { + c.logger.Error("core: failed to persist local auth table", "error", err) + return err + } + return nil } @@ -312,15 +372,21 @@ func (c *Core) setupCredentials() error { } // Create a barrier view using the UUID - view = NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/") + viewPath := credentialBarrierPrefix + entry.UUID + "/" + view = NewBarrierView(c.barrier, viewPath) + sysView := c.mountEntrySysView(entry) // Initialize the backend - backend, err = c.newCredentialBackend(entry.Type, c.mountEntrySysView(entry), view, nil) + backend, err = c.newCredentialBackend(entry.Type, sysView, view, nil) if err != nil { c.logger.Error("core: failed to create credential entry", "path", entry.Path, "error", err) return errLoadAuthFailed } + if err := backend.Initialize(); err != nil { + return err + } + // Mount the backend path := credentialRoutePrefix + entry.Path err = c.router.Mount(backend, path, entry, view) diff --git a/vault/auth_test.go b/vault/auth_test.go index 75caa789fe84..41f9eb37bb6c 100644 --- a/vault/auth_test.go +++ b/vault/auth_test.go @@ -2,8 +2,10 @@ package vault import ( "reflect" + "strings" "testing" + "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/logical" ) @@ -84,6 +86,88 @@ func TestCore_EnableCredential(t *testing.T) { } } +// Test that the local table actually gets populated as expected with local +// entries, and that upon reading the entries from both are recombined +// correctly +func TestCore_EnableCredential_Local(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + c.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) { + return &NoopBackend{}, nil + } + + c.auth = &MountTable{ + Type: credentialTableType, + Entries: []*MountEntry{ + &MountEntry{ + Table: credentialTableType, + Path: "noop/", + Type: "noop", + UUID: "abcd", + }, + &MountEntry{ + Table: credentialTableType, + Path: "noop2/", + Type: "noop", + UUID: "bcde", + }, + }, + } + + // Both should set up successfully + err := c.setupCredentials() + if err != nil { + t.Fatal(err) + } + + rawLocal, err := c.barrier.Get(coreLocalAuthConfigPath) + if err != nil { + t.Fatal(err) + } + if rawLocal == nil { + t.Fatal("expected non-nil local credential") + } + localCredentialTable := &MountTable{} + if err := jsonutil.DecodeJSON(rawLocal.Value, localCredentialTable); err != nil { + t.Fatal(err) + } + if len(localCredentialTable.Entries) > 0 { + t.Fatalf("expected no entries in local credential table, got %#v", localCredentialTable) + } + + c.auth.Entries[1].Local = true + if err := c.persistAuth(c.auth); err != nil { + t.Fatal(err) + } + + rawLocal, err = c.barrier.Get(coreLocalAuthConfigPath) + if err != nil { + t.Fatal(err) + } + if rawLocal == nil { + t.Fatal("expected non-nil local credential") + } + localCredentialTable = &MountTable{} + if err := jsonutil.DecodeJSON(rawLocal.Value, localCredentialTable); err != nil { + t.Fatal(err) + } + if len(localCredentialTable.Entries) != 1 { + t.Fatalf("expected one entry in local credential table, got %#v", localCredentialTable) + } + + oldCredential := c.auth + if err := c.loadCredentials(); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(oldCredential, c.auth) { + t.Fatalf("expected\n%#v\ngot\n%#v\n", oldCredential, c.auth) + } + + if len(c.auth.Entries) != 2 { + t.Fatalf("expected two credential entries, got %#v", localCredentialTable) + } +} + func TestCore_EnableCredential_twice_409(t *testing.T) { c, _, _ := TestCoreUnsealed(t) c.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) { @@ -132,7 +216,7 @@ func TestCore_DisableCredential(t *testing.T) { } existed, err := c.disableCredential("foo") - if existed || err.Error() != "no matching backend" { + if existed || (err != nil && !strings.HasPrefix(err.Error(), "no matching backend")) { t.Fatalf("existed: %v; err: %v", existed, err) } diff --git a/vault/barrier.go b/vault/barrier.go index df0660d0fe5c..7c9acc0550ea 100644 --- a/vault/barrier.go +++ b/vault/barrier.go @@ -86,6 +86,11 @@ type SecurityBarrier interface { // VerifyMaster is used to check if the given key matches the master key VerifyMaster(key []byte) error + // SetMasterKey is used to directly set a new master key. This is used in + // repliated scenarios due to the chicken and egg problem of reloading the + // keyring from disk before we have the master key to decrypt it. + SetMasterKey(key []byte) error + // ReloadKeyring is used to re-read the underlying keyring. // This is used for HA deployments to ensure the latest keyring // is present in the leader. @@ -119,8 +124,14 @@ type SecurityBarrier interface { // Rekey is used to change the master key used to protect the keyring Rekey([]byte) error + // For replication we must send over the keyring, so this must be available + Keyring() (*Keyring, error) + // SecurityBarrier must provide the storage APIs BarrierStorage + + // SecurityBarrier must provide the encryption APIs + BarrierEncryptor } // BarrierStorage is the storage only interface required for a Barrier. @@ -139,6 +150,14 @@ type BarrierStorage interface { List(prefix string) ([]string, error) } +// BarrierEncryptor is the in memory only interface that does not actually +// use the underlying barrier. It is used for lower level modules like the +// Write-Ahead-Log and Merkle index to allow them to use the barrier. +type BarrierEncryptor interface { + Encrypt(key string, plaintext []byte) ([]byte, error) + Decrypt(key string, ciphertext []byte) ([]byte, error) +} + // Entry is used to represent data stored by the security barrier type Entry struct { Key string diff --git a/vault/barrier_aes_gcm.go b/vault/barrier_aes_gcm.go index 56ebeb8c0c92..37c191bd6b02 100644 --- a/vault/barrier_aes_gcm.go +++ b/vault/barrier_aes_gcm.go @@ -574,19 +574,12 @@ func (b *AESGCMBarrier) ActiveKeyInfo() (*KeyInfo, error) { func (b *AESGCMBarrier) Rekey(key []byte) error { b.l.Lock() defer b.l.Unlock() - if b.sealed { - return ErrBarrierSealed - } - // Verify the key size - min, max := b.KeyLength() - if len(key) < min || len(key) > max { - return fmt.Errorf("Key size must be %d or %d", min, max) + newKeyring, err := b.updateMasterKeyCommon(key) + if err != nil { + return err } - // Add a new encryption key - newKeyring := b.keyring.SetMasterKey(key) - // Persist the new keyring if err := b.persistKeyring(newKeyring); err != nil { return err @@ -599,6 +592,40 @@ func (b *AESGCMBarrier) Rekey(key []byte) error { return nil } +// SetMasterKey updates the keyring's in-memory master key but does not persist +// anything to storage +func (b *AESGCMBarrier) SetMasterKey(key []byte) error { + b.l.Lock() + defer b.l.Unlock() + + newKeyring, err := b.updateMasterKeyCommon(key) + if err != nil { + return err + } + + // Swap the keyrings + oldKeyring := b.keyring + b.keyring = newKeyring + oldKeyring.Zeroize(false) + return nil +} + +// Performs common tasks related to updating the master key; note that the lock +// must be held before calling this function +func (b *AESGCMBarrier) updateMasterKeyCommon(key []byte) (*Keyring, error) { + if b.sealed { + return nil, ErrBarrierSealed + } + + // Verify the key size + min, max := b.KeyLength() + if len(key) < min || len(key) > max { + return nil, fmt.Errorf("Key size must be %d or %d", min, max) + } + + return b.keyring.SetMasterKey(key), nil +} + // Put is used to insert or update an entry func (b *AESGCMBarrier) Put(entry *Entry) error { defer metrics.MeasureSince([]string{"barrier", "put"}, time.Now()) @@ -813,3 +840,47 @@ func (b *AESGCMBarrier) decryptKeyring(path string, cipher []byte) ([]byte, erro return nil, fmt.Errorf("version bytes mis-match") } } + +// Encrypt is used to encrypt in-memory for the BarrierEncryptor interface +func (b *AESGCMBarrier) Encrypt(key string, plaintext []byte) ([]byte, error) { + b.l.RLock() + defer b.l.RUnlock() + if b.sealed { + return nil, ErrBarrierSealed + } + + term := b.keyring.ActiveTerm() + primary, err := b.aeadForTerm(term) + if err != nil { + return nil, err + } + + ciphertext := b.encrypt(key, term, primary, plaintext) + return ciphertext, nil +} + +// Decrypt is used to decrypt in-memory for the BarrierEncryptor interface +func (b *AESGCMBarrier) Decrypt(key string, ciphertext []byte) ([]byte, error) { + b.l.RLock() + defer b.l.RUnlock() + if b.sealed { + return nil, ErrBarrierSealed + } + + // Decrypt the ciphertext + plain, err := b.decryptKeyring(key, ciphertext) + if err != nil { + return nil, fmt.Errorf("decryption failed: %v", err) + } + return plain, nil +} + +func (b *AESGCMBarrier) Keyring() (*Keyring, error) { + b.l.RLock() + defer b.l.RUnlock() + if b.sealed { + return nil, ErrBarrierSealed + } + + return b.keyring.Clone(), nil +} diff --git a/vault/barrier_aes_gcm_test.go b/vault/barrier_aes_gcm_test.go index 3814139bf5d4..7d575cee839d 100644 --- a/vault/barrier_aes_gcm_test.go +++ b/vault/barrier_aes_gcm_test.go @@ -433,3 +433,30 @@ func TestInitialize_KeyLength(t *testing.T) { t.Fatalf("key length protection failed") } } + +func TestEncrypt_BarrierEncryptor(t *testing.T) { + inm := physical.NewInmem(logger) + b, err := NewAESGCMBarrier(inm) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Initialize and unseal + key, _ := b.GenerateKey() + b.Initialize(key) + b.Unseal(key) + + cipher, err := b.Encrypt("foo", []byte("quick brown fox")) + if err != nil { + t.Fatalf("err: %v", err) + } + + plain, err := b.Decrypt("foo", cipher) + if err != nil { + t.Fatalf("err: %v", err) + } + + if string(plain) != "quick brown fox" { + t.Fatalf("bad: %s", plain) + } +} diff --git a/vault/barrier_view.go b/vault/barrier_view.go index b0dbbf2d1d38..0fa6f2d78f4f 100644 --- a/vault/barrier_view.go +++ b/vault/barrier_view.go @@ -69,14 +69,18 @@ func (v *BarrierView) Get(key string) (*logical.StorageEntry, error) { // logical.Storage impl. func (v *BarrierView) Put(entry *logical.StorageEntry) error { - if v.readonly { - return logical.ErrReadOnly - } if err := v.sanityCheck(entry.Key); err != nil { return err } + + expandedKey := v.expandKey(entry.Key) + + if v.readonly { + return logical.ErrReadOnly + } + nested := &Entry{ - Key: v.expandKey(entry.Key), + Key: expandedKey, Value: entry.Value, } return v.barrier.Put(nested) @@ -84,13 +88,18 @@ func (v *BarrierView) Put(entry *logical.StorageEntry) error { // logical.Storage impl. func (v *BarrierView) Delete(key string) error { - if v.readonly { - return logical.ErrReadOnly - } if err := v.sanityCheck(key); err != nil { return err } - return v.barrier.Delete(v.expandKey(key)) + + expandedKey := v.expandKey(key) + + if v.readonly { + return logical.ErrReadOnly + } + + + return v.barrier.Delete(expandedKey) } // SubView constructs a nested sub-view using the given prefix diff --git a/vault/capabilities.go b/vault/capabilities.go index 4d3add4950ff..6994e52edff1 100644 --- a/vault/capabilities.go +++ b/vault/capabilities.go @@ -1,27 +1,19 @@ package vault -import "sort" +import ( + "sort" -// Struct to identify user input errors. -// This is helpful in responding the appropriate status codes to clients -// from the HTTP endpoints. -type StatusBadRequest struct { - Err string -} - -// Implementing error interface -func (s *StatusBadRequest) Error() string { - return s.Err -} + "github.com/hashicorp/vault/logical" +) // Capabilities is used to fetch the capabilities of the given token on the given path func (c *Core) Capabilities(token, path string) ([]string, error) { if path == "" { - return nil, &StatusBadRequest{Err: "missing path"} + return nil, &logical.StatusBadRequest{Err: "missing path"} } if token == "" { - return nil, &StatusBadRequest{Err: "missing token"} + return nil, &logical.StatusBadRequest{Err: "missing token"} } te, err := c.tokenStore.Lookup(token) @@ -29,7 +21,7 @@ func (c *Core) Capabilities(token, path string) ([]string, error) { return nil, err } if te == nil { - return nil, &StatusBadRequest{Err: "invalid token"} + return nil, &logical.StatusBadRequest{Err: "invalid token"} } if te.Policies == nil { diff --git a/vault/init.go b/vault/init.go index 221a9bd051c3..3e267fd011d9 100644 --- a/vault/init.go +++ b/vault/init.go @@ -133,36 +133,12 @@ func (c *Core) Initialize(initParams *InitParams) (*InitResult, error) { return nil, fmt.Errorf("error initializing seal: %v", err) } - err = c.seal.SetBarrierConfig(barrierConfig) - if err != nil { - c.logger.Error("core: failed to save barrier configuration", "error", err) - return nil, fmt.Errorf("barrier configuration saving failed: %v", err) - } - barrierKey, barrierUnsealKeys, err := c.generateShares(barrierConfig) if err != nil { c.logger.Error("core: error generating shares", "error", err) return nil, err } - // If we are storing shares, pop them out of the returned results and push - // them through the seal - if barrierConfig.StoredShares > 0 { - var keysToStore [][]byte - for i := 0; i < barrierConfig.StoredShares; i++ { - keysToStore = append(keysToStore, barrierUnsealKeys[0]) - barrierUnsealKeys = barrierUnsealKeys[1:] - } - if err := c.seal.SetStoredKeys(keysToStore); err != nil { - c.logger.Error("core: failed to store keys", "error", err) - return nil, fmt.Errorf("failed to store keys: %v", err) - } - } - - results := &InitResult{ - SecretShares: barrierUnsealKeys, - } - // Initialize the barrier if err := c.barrier.Initialize(barrierKey); err != nil { c.logger.Error("core: failed to initialize barrier", "error", err) @@ -180,11 +156,38 @@ func (c *Core) Initialize(initParams *InitParams) (*InitResult, error) { // Ensure the barrier is re-sealed defer func() { + // Defers are LIFO so we need to run this here too to ensure the stop + // happens before sealing. preSeal also stops, so we just make the + // stopping safe against multiple calls. if err := c.barrier.Seal(); err != nil { c.logger.Error("core: failed to seal barrier", "error", err) } }() + err = c.seal.SetBarrierConfig(barrierConfig) + if err != nil { + c.logger.Error("core: failed to save barrier configuration", "error", err) + return nil, fmt.Errorf("barrier configuration saving failed: %v", err) + } + + // If we are storing shares, pop them out of the returned results and push + // them through the seal + if barrierConfig.StoredShares > 0 { + var keysToStore [][]byte + for i := 0; i < barrierConfig.StoredShares; i++ { + keysToStore = append(keysToStore, barrierUnsealKeys[0]) + barrierUnsealKeys = barrierUnsealKeys[1:] + } + if err := c.seal.SetStoredKeys(keysToStore); err != nil { + c.logger.Error("core: failed to store keys", "error", err) + return nil, fmt.Errorf("failed to store keys: %v", err) + } + } + + results := &InitResult{ + SecretShares: barrierUnsealKeys, + } + // Perform initial setup if err := c.setupCluster(); err != nil { c.logger.Error("core: cluster setup failed during init", "error", err) diff --git a/vault/rekey_test.go b/vault/rekey_test.go index dacba36d9dd8..c463325fe4bc 100644 --- a/vault/rekey_test.go +++ b/vault/rekey_test.go @@ -237,7 +237,7 @@ func testCore_Rekey_Update_Common(t *testing.T, c *Core, keys [][]byte, root str t.Fatalf("err: %v", err) } for i := 0; i < 3; i++ { - _, err = TestCoreUnseal(c, result.SecretShares[i]) + _, err = TestCoreUnseal(c, TestKeyCopy(result.SecretShares[i])) if err != nil { t.Fatalf("err: %v", err) } @@ -270,7 +270,7 @@ func testCore_Rekey_Update_Common(t *testing.T, c *Core, keys [][]byte, root str // Provide the parts master oldResult := result for i := 0; i < 3; i++ { - result, err = c.RekeyUpdate(oldResult.SecretShares[i], rkconf.Nonce, recovery) + result, err = c.RekeyUpdate(TestKeyCopy(oldResult.SecretShares[i]), rkconf.Nonce, recovery) if err != nil { t.Fatalf("err: %v", err) } diff --git a/vault/request_handling.go b/vault/request_handling.go index 2c7c45434952..c73270f2aa7e 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -184,7 +184,7 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r } // Route the request - resp, err := c.router.Route(req) + resp, routeErr := c.router.Route(req) if resp != nil { // If wrapping is used, use the shortest between the request and response var wrapTTL time.Duration @@ -306,8 +306,8 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r } // Return the response and error - if err != nil { - retErr = multierror.Append(retErr, err) + if routeErr != nil { + retErr = multierror.Append(retErr, routeErr) } return resp, auth, retErr } @@ -331,7 +331,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log } // Route the request - resp, err := c.router.Route(req) + resp, routeErr := c.router.Route(req) if resp != nil { // If wrapping is used, use the shortest between the request and response var wrapTTL time.Duration @@ -446,5 +446,5 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log req.DisplayName = auth.DisplayName } - return resp, auth, err + return resp, auth, routeErr } diff --git a/vault/testing.go b/vault/testing.go index 6b7706ce2879..2229c7b996f1 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -243,8 +243,12 @@ func testTokenStore(t testing.TB, c *Core) *TokenStore { me.UUID = meUUID view := NewBarrierView(c.barrier, credentialBarrierPrefix+me.UUID+"/") + sysView := c.mountEntrySysView(me) - tokenstore, _ := c.newCredentialBackend("token", c.mountEntrySysView(me), view, nil) + tokenstore, _ := c.newCredentialBackend("token", sysView, view, nil) + if err := tokenstore.Initialize(); err != nil { + panic(err) + } ts := tokenstore.(*TokenStore) router := NewRouter() diff --git a/vault/token_store.go b/vault/token_store.go index 0db5547e56ca..48a1d1852494 100644 --- a/vault/token_store.go +++ b/vault/token_store.go @@ -109,19 +109,10 @@ func NewTokenStore(c *Core, config *logical.BackendConfig) (*TokenStore, error) t.policyLookupFunc = c.policyStore.GetPolicy } - // Setup the salt - salt, err := salt.NewSalt(view, &salt.Config{ - HashFunc: salt.SHA1Hash, - }) - if err != nil { - return nil, err - } - t.salt = salt - t.tokenLocks = map[string]*sync.RWMutex{} // Create 256 locks - if err = locksutil.CreateLocks(t.tokenLocks, 256); err != nil { + if err := locksutil.CreateLocks(t.tokenLocks, 256); err != nil { return nil, fmt.Errorf("failed to create locks: %v", err) } @@ -136,6 +127,15 @@ func NewTokenStore(c *Core, config *logical.BackendConfig) (*TokenStore, error) "revoke-orphan/*", "accessors*", }, + + // Most token store items are local since tokens are local, but a + // notable exception is roles + LocalStorage: []string{ + lookupPrefix, + accessorPrefix, + parentPrefix, + "salt", + }, }, Paths: []*framework.Path{ @@ -467,6 +467,8 @@ func NewTokenStore(c *Core, config *logical.BackendConfig) (*TokenStore, error) HelpDescription: strings.TrimSpace(tokenTidyDesc), }, }, + + Init: t.Initialize, } t.Backend.Setup(config) @@ -474,6 +476,19 @@ func NewTokenStore(c *Core, config *logical.BackendConfig) (*TokenStore, error) return t, nil } +func (ts *TokenStore) Initialize() error { + // Setup the salt + salt, err := salt.NewSalt(ts.view, &salt.Config{ + HashFunc: salt.SHA1Hash, + }) + if err != nil { + return err + } + ts.salt = salt + + return nil +} + // TokenEntry is used to represent a given token type TokenEntry struct { // ID of this entry, generally a random UUID @@ -1085,7 +1100,7 @@ func (ts *TokenStore) lookupBySaltedAccessor(saltedAccessor string) (accessorEnt return aEntry, fmt.Errorf("failed to read index using accessor: %s", err) } if entry == nil { - return aEntry, &StatusBadRequest{Err: "invalid accessor"} + return aEntry, &logical.StatusBadRequest{Err: "invalid accessor"} } err = jsonutil.DecodeJSON(entry.Value, &aEntry) @@ -1225,7 +1240,7 @@ func (ts *TokenStore) handleUpdateLookupAccessor(req *logical.Request, data *fra if accessor == "" { accessor = data.Get("urlaccessor").(string) if accessor == "" { - return nil, &StatusBadRequest{Err: "missing accessor"} + return nil, &logical.StatusBadRequest{Err: "missing accessor"} } urlaccessor = true } @@ -1279,7 +1294,7 @@ func (ts *TokenStore) handleUpdateRevokeAccessor(req *logical.Request, data *fra if accessor == "" { accessor = data.Get("urlaccessor").(string) if accessor == "" { - return nil, &StatusBadRequest{Err: "missing accessor"} + return nil, &logical.StatusBadRequest{Err: "missing accessor"} } urlaccessor = true } diff --git a/vault/token_store_test.go b/vault/token_store_test.go index 745c6e3541bc..69e5b7ca4512 100644 --- a/vault/token_store_test.go +++ b/vault/token_store_test.go @@ -437,6 +437,9 @@ func TestTokenStore_CreateLookup(t *testing.T) { if err != nil { t.Fatalf("err: %v", err) } + if err := ts2.Initialize(); err != nil { + t.Fatalf("err: %v", err) + } // Should still match out, err = ts2.Lookup(ent.ID) @@ -476,6 +479,9 @@ func TestTokenStore_CreateLookup_ProvidedID(t *testing.T) { if err != nil { t.Fatalf("err: %v", err) } + if err := ts2.Initialize(); err != nil { + t.Fatalf("err: %v", err) + } // Should still match out, err = ts2.Lookup(ent.ID) From 3588b233f8679c14062dfa79febe9f31ff8dde22 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 16 Feb 2017 22:53:25 -0500 Subject: [PATCH 2/2] Add a bit more porting --- vault/request_forwarding.go | 41 +++++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/vault/request_forwarding.go b/vault/request_forwarding.go index f29cb95b0bd9..04f65cef0b69 100644 --- a/vault/request_forwarding.go +++ b/vault/request_forwarding.go @@ -27,6 +27,9 @@ func (c *Core) startForwarding() error { // Clean up in case we have transitioned from a client to a server c.clearForwardingClients() + // Resolve locally to avoid races + ha := c.ha != nil + // Get our base handler (for our RPC server) and our wrapped handler (for // straight HTTP/2 forwarding) baseHandler, wrappedHandler := c.clusterHandlerSetupFunc() @@ -43,10 +46,13 @@ func (c *Core) startForwarding() error { // Create our RPC server and register the request handler server c.rpcServer = grpc.NewServer() - RegisterRequestForwardingServer(c.rpcServer, &forwardedRequestRPCServer{ - core: c, - handler: baseHandler, - }) + + if ha { + RegisterRequestForwardingServer(c.rpcServer, &forwardedRequestRPCServer{ + core: c, + handler: baseHandler, + }) + } // Create the HTTP/2 server that will be shared by both RPC and regular // duties. Doing it this way instead of listening via the server and gRPC @@ -82,6 +88,7 @@ func (c *Core) startForwarding() error { // Wrap the listener with TLS tlsLn := tls.NewListener(tcpLn, tlsConfig) + defer tlsLn.Close() if c.logger.IsInfo() { c.logger.Info("core/startClusterListener: serving cluster requests", "cluster_listen_address", tlsLn.Addr()) @@ -89,7 +96,6 @@ func (c *Core) startForwarding() error { for { if atomic.LoadUint32(&shutdown) > 0 { - tlsLn.Close() return } @@ -100,10 +106,11 @@ func (c *Core) startForwarding() error { // Accept the connection conn, err := tlsLn.Accept() + if conn != nil { + // Always defer although it may be closed ahead of time + defer conn.Close() + } if err != nil { - if conn != nil { - conn.Close() - } continue } @@ -123,19 +130,29 @@ func (c *Core) startForwarding() error { switch tlsConn.ConnectionState().NegotiatedProtocol { case "h2": + if !ha { + conn.Close() + continue + } + c.logger.Debug("core/startClusterListener/Accept: got h2 connection") go fws.ServeConn(conn, &http2.ServeConnOpts{ Handler: wrappedHandler, }) case "req_fw_sb-act_v1": + if !ha { + conn.Close() + continue + } + c.logger.Debug("core/startClusterListener/Accept: got req_fw_sb-act_v1 connection") go fws.ServeConn(conn, &http2.ServeConnOpts{ Handler: c.rpcServer, }) default: - c.logger.Debug("core/startClusterListener/Accept: unknown negotiated protocol") + c.logger.Debug("core: unknown negotiated protocol on cluster port") conn.Close() continue } @@ -154,8 +171,9 @@ func (c *Core) startForwarding() error { <-c.clusterListenerShutdownCh // Stop the RPC server + c.logger.Info("core: shutting down forwarding rpc listeners") c.rpcServer.Stop() - c.logger.Info("core/startClusterListener: shutting down listeners") + c.logger.Info("core: forwarding rpc listeners stopped") // Set the shutdown flag. This will cause the listeners to shut down // within the deadline in clusterListenerAcceptDeadline @@ -163,7 +181,7 @@ func (c *Core) startForwarding() error { // Wait for them all to shut down shutdownWg.Wait() - c.logger.Info("core/startClusterListener: listeners successfully shut down") + c.logger.Info("core: rpc listeners successfully shut down") // Tell the main thread that shutdown is done. c.clusterListenerShutdownSuccessCh <- struct{}{} @@ -223,6 +241,7 @@ func (c *Core) refreshRequestForwardingConnection(clusterAddr string) error { // It's not really insecure, but we have to dial manually to get the // ALPN header right. It's just "insecure" because GRPC isn't managing // the TLS state. + ctx, cancelFunc := context.WithCancel(context.Background()) c.rpcClientConnCancelFunc = cancelFunc c.rpcClientConn, err = grpc.DialContext(ctx, clusterURL.Host, grpc.WithDialer(c.getGRPCDialer("req_fw_sb-act_v1", "")), grpc.WithInsecure())