From 21274610a963c37431b1d8eed7b21641fafb5182 Mon Sep 17 00:00:00 2001
From: Mike Jarmy
Date: Tue, 16 Jun 2020 14:12:22 -0400
Subject: [PATCH] Test Shamir-to-Transit and Transit-to-Shamir Seal Migration
 for post-1.4 Vault. (#9214)

* move adjustForSealMigration to vault package
* fix adjustForSealMigration
* begin working on new seal migration test
* create shamir seal migration test
* refactor testhelpers
* add VerifyRaftConfiguration to testhelpers
* stub out TestTransit
* Revert "refactor testhelpers"

This reverts commit 39593defd0d4c6fd79aedfd37df6298391abb9db.

* get shamir test working again
* stub out transit join
* work on transit join
* remove debug code
* initTransit now works with raft join
* runTransit works with inmem
* work on runTransit with raft
* runTransit works with raft
* cleanup tests
* TestSealMigration_TransitToShamir_Pre14
* TestSealMigration_ShamirToTransit_Pre14
* split for pre-1.4 testing
* add simple tests for transit and shamir
* fix typo in test suite
* debug wrapper type
* test debug
* test-debug
* refactor core migration
* Revert "refactor core migration"

This reverts commit a776452d32a9dca7a51e3df4a76b9234d8c0c7ce.

* begin refactor of adjustForSealMigration
* fix bug in adjustForSealMigration
* clean up tests
* clean up core refactoring
* fix bug in shamir->transit migration
* stub out test that brings individual nodes up and down
* refactor NewTestCluster
* pass listeners into newCore()
* simplify cluster address setup
* simplify extra test core setup
* refactor TestCluster for readability
* refactor TestCluster for readability
* refactor TestCluster for readability
* add shutdown func to TestCore
* add cleanup func to TestCore
* create RestartCore
* stub out TestSealMigration_ShamirToTransit_Post14
* refactor address handling in NewTestCluster
* fix listener setup in newCore()
* remove unnecessary lock from setSealsForMigration()
* rename sealmigration test package
* use ephemeral ports below 30000
* work on post-1.4 migration testing
* clean up pre-1.4 test
* TestSealMigration_ShamirToTransit_Post14 works for non-raft
* work on raft TestSealMigration_ShamirToTransit_Post14
* clean up test code
* refactor TestClusterCore
* clean up TestClusterCore
* stub out some temporary tests
* use HardcodedServerAddressProvider in seal migration tests
* work on raft for TestSealMigration_ShamirToTransit_Post14
* always use hardcoded raft address provider in seal migration tests
* debug TestSealMigration_ShamirToTransit_Post14
* fix bug in RestartCore
* remove debug code
* TestSealMigration_ShamirToTransit_Post14 works now
* clean up debug code
* clean up tests
* cleanup tests
* refactor test code
* stub out TestSealMigration_TransitToShamir_Post14
* set seals properly for transit->shamir migration
* migrateFromTransitToShamir_Post14 works for inmem
* migrateFromTransitToShamir_Post14 works for raft
* use base ports per-test
* fix seal verification test code
* simplify seal migration test suite
* simplify test suite
* cleanup test suite
* use explicit ports below 30000
* simplify use of numTestCores
* Update vault/external_tests/sealmigration/seal_migration_test.go

Co-authored-by: Calvin Leung Huang

* Update vault/external_tests/sealmigration/seal_migration_test.go

Co-authored-by: Calvin Leung Huang

* clean up imports
* rename to StartCore()
* Update vault/testing.go

Co-authored-by: Calvin Leung Huang

* simplify test suite
* clean up tests

Co-authored-by: Calvin Leung Huang
---
 helper/testhelpers/testhelpers.go | 86 +-
 .../teststorage/teststorage_reusable.go | 30 +-
.../seal_migration_pre14_test.go | 8 +- .../sealmigration/seal_migration_test.go | 473 ++++++++-- vault/testing.go | 837 +++++++++++------- vault/testing_util.go | 6 +- 6 files changed, 955 insertions(+), 485 deletions(-) diff --git a/helper/testhelpers/testhelpers.go b/helper/testhelpers/testhelpers.go index b9aff79f3b14..6bae6e68708e 100644 --- a/helper/testhelpers/testhelpers.go +++ b/helper/testhelpers/testhelpers.go @@ -412,16 +412,9 @@ func (p *TestRaftServerAddressProvider) ServerAddr(id raftlib.ServerID) (raftlib } func RaftClusterJoinNodes(t testing.T, cluster *vault.TestCluster) { - raftClusterJoinNodes(t, cluster, false) -} - -func RaftClusterJoinNodesWithStoredKeys(t testing.T, cluster *vault.TestCluster) { - raftClusterJoinNodes(t, cluster, true) -} - -func raftClusterJoinNodes(t testing.T, cluster *vault.TestCluster, useStoredKeys bool) { addressProvider := &TestRaftServerAddressProvider{Cluster: cluster} + atomic.StoreUint32(&vault.UpdateClusterAddrForTests, 1) leader := cluster.Cores[0] @@ -430,11 +423,7 @@ func raftClusterJoinNodes(t testing.T, cluster *vault.TestCluster, useStoredKeys { EnsureCoreSealed(t, leader) leader.UnderlyingRawStorage.(*raft.RaftBackend).SetServerAddressProvider(addressProvider) - if useStoredKeys { - cluster.UnsealCoreWithStoredKeys(t, leader) - } else { - cluster.UnsealCore(t, leader) - } + cluster.UnsealCore(t, leader) vault.TestWaitActive(t, leader.Core) } @@ -454,37 +443,12 @@ func raftClusterJoinNodes(t testing.T, cluster *vault.TestCluster, useStoredKeys t.Fatal(err) } - if useStoredKeys { - // For autounseal, the raft backend is not initialized right away - // after the join. We need to wait briefly before we can unseal. - awaitUnsealWithStoredKeys(t, core) - } else { - cluster.UnsealCore(t, core) - } + cluster.UnsealCore(t, core) } WaitForNCoresUnsealed(t, cluster, len(cluster.Cores)) } -func awaitUnsealWithStoredKeys(t testing.T, core *vault.TestClusterCore) { - - timeout := time.Now().Add(30 * time.Second) - for { - if time.Now().After(timeout) { - t.Fatal("raft join: timeout waiting for core to unseal") - } - // Its actually ok for an error to happen here the first couple of - // times -- it means the raft join hasn't gotten around to initializing - // the backend yet. - err := core.UnsealWithStoredKeys(context.Background()) - if err == nil { - return - } - core.Logger().Warn("raft join: failed to unseal core", "error", err) - time.Sleep(time.Second) - } -} - // HardcodedServerAddressProvider is a ServerAddressProvider that uses // a hardcoded map of raft node addresses. // @@ -505,11 +469,11 @@ func (p *HardcodedServerAddressProvider) ServerAddr(id raftlib.ServerID) (raftli // NewHardcodedServerAddressProvider is a convenience function that makes a // ServerAddressProvider from a given cluster address base port. -func NewHardcodedServerAddressProvider(cluster *vault.TestCluster, baseClusterPort int) raftlib.ServerAddressProvider { +func NewHardcodedServerAddressProvider(numCores, baseClusterPort int) raftlib.ServerAddressProvider { entries := make(map[raftlib.ServerID]raftlib.ServerAddress) - for i := 0; i < len(cluster.Cores); i++ { + for i := 0; i < numCores; i++ { id := fmt.Sprintf("core-%d", i) addr := fmt.Sprintf("127.0.0.1:%d", baseClusterPort+i) entries[raftlib.ServerID(id)] = raftlib.ServerAddress(addr) @@ -520,17 +484,6 @@ func NewHardcodedServerAddressProvider(cluster *vault.TestCluster, baseClusterPo } } -// SetRaftAddressProviders sets a ServerAddressProvider for all the nodes in a -// cluster. 
-func SetRaftAddressProviders(t testing.T, cluster *vault.TestCluster, provider raftlib.ServerAddressProvider) { - - atomic.StoreUint32(&vault.UpdateClusterAddrForTests, 1) - - for _, core := range cluster.Cores { - core.UnderlyingRawStorage.(*raft.RaftBackend).SetServerAddressProvider(provider) - } -} - // VerifyRaftConfiguration checks that we have a valid raft configuration, i.e. // the correct number of servers, having the correct NodeIDs, and exactly one // leader. @@ -565,6 +518,35 @@ func VerifyRaftConfiguration(core *vault.TestClusterCore, numCores int) error { return nil } +// AwaitLeader waits for one of the cluster's nodes to become leader. +func AwaitLeader(t testing.T, cluster *vault.TestCluster) (int, error) { + + timeout := time.Now().Add(30 * time.Second) + for { + if time.Now().After(timeout) { + break + } + + for i, core := range cluster.Cores { + if core.Core.Sealed() { + continue + } + + isLeader, _, _, err := core.Leader() + if err != nil { + t.Fatal(err) + } + if isLeader { + return i, nil + } + } + + time.Sleep(time.Second) + } + + return 0, fmt.Errorf("timeout waiting leader") +} + func GenerateDebugLogs(t testing.T, client *api.Client) chan struct{} { t.Helper() diff --git a/helper/testhelpers/teststorage/teststorage_reusable.go b/helper/testhelpers/teststorage/teststorage_reusable.go index bb78a6d972ca..3546fb7e528b 100644 --- a/helper/testhelpers/teststorage/teststorage_reusable.go +++ b/helper/testhelpers/teststorage/teststorage_reusable.go @@ -8,6 +8,7 @@ import ( "github.com/mitchellh/go-testing-interface" hclog "github.com/hashicorp/go-hclog" + raftlib "github.com/hashicorp/raft" "github.com/hashicorp/vault/physical/raft" "github.com/hashicorp/vault/vault" ) @@ -73,7 +74,10 @@ func MakeReusableStorage(t testing.T, logger hclog.Logger, bundle *vault.Physica // MakeReusableRaftStorage makes a physical raft backend that can be re-used // across multiple test clusters in sequence. -func MakeReusableRaftStorage(t testing.T, logger hclog.Logger, numCores int) (ReusableStorage, StorageCleanup) { +func MakeReusableRaftStorage( + t testing.T, logger hclog.Logger, numCores int, + addressProvider raftlib.ServerAddressProvider, +) (ReusableStorage, StorageCleanup) { raftDirs := make([]string, numCores) for i := 0; i < numCores; i++ { @@ -87,17 +91,14 @@ func MakeReusableRaftStorage(t testing.T, logger hclog.Logger, numCores int) (Re conf.DisablePerformanceStandby = true opts.KeepStandbysSealed = true opts.PhysicalFactory = func(t testing.T, coreIdx int, logger hclog.Logger) *vault.PhysicalBackendBundle { - return makeReusableRaftBackend(t, coreIdx, logger, raftDirs[coreIdx]) + return makeReusableRaftBackend(t, coreIdx, logger, raftDirs[coreIdx], addressProvider) } }, // Close open files being used by raft. Cleanup: func(t testing.T, cluster *vault.TestCluster) { - for _, core := range cluster.Cores { - raftStorage := core.UnderlyingRawStorage.(*raft.RaftBackend) - if err := raftStorage.Close(); err != nil { - t.Fatal(err) - } + for i := 0; i < len(cluster.Cores); i++ { + CloseRaftStorage(t, cluster, i) } }, } @@ -111,6 +112,14 @@ func MakeReusableRaftStorage(t testing.T, logger hclog.Logger, numCores int) (Re return storage, cleanup } +// CloseRaftStorage closes open files being used by raft. 
+func CloseRaftStorage(t testing.T, cluster *vault.TestCluster, idx int) { + raftStorage := cluster.Cores[idx].UnderlyingRawStorage.(*raft.RaftBackend) + if err := raftStorage.Close(); err != nil { + t.Fatal(err) + } +} + func makeRaftDir(t testing.T) string { raftDir, err := ioutil.TempDir("", "vault-raft-") if err != nil { @@ -120,7 +129,10 @@ func makeRaftDir(t testing.T) string { return raftDir } -func makeReusableRaftBackend(t testing.T, coreIdx int, logger hclog.Logger, raftDir string) *vault.PhysicalBackendBundle { +func makeReusableRaftBackend( + t testing.T, coreIdx int, logger hclog.Logger, raftDir string, + addressProvider raftlib.ServerAddressProvider, +) *vault.PhysicalBackendBundle { nodeID := fmt.Sprintf("core-%d", coreIdx) conf := map[string]string{ @@ -134,6 +146,8 @@ func makeReusableRaftBackend(t testing.T, coreIdx int, logger hclog.Logger, raft t.Fatal(err) } + backend.(*raft.RaftBackend).SetServerAddressProvider(addressProvider) + return &vault.PhysicalBackendBundle{ Backend: backend, } diff --git a/vault/external_tests/sealmigration/seal_migration_pre14_test.go b/vault/external_tests/sealmigration/seal_migration_pre14_test.go index 6f72449da118..cdfc368f61ea 100644 --- a/vault/external_tests/sealmigration/seal_migration_pre14_test.go +++ b/vault/external_tests/sealmigration/seal_migration_pre14_test.go @@ -25,7 +25,7 @@ import ( func TestSealMigration_TransitToShamir_Pre14(t *testing.T) { // Note that we do not test integrated raft storage since this is // a pre-1.4 test. - testVariousBackends(t, testSealMigrationTransitToShamir_Pre14, false) + testVariousBackends(t, testSealMigrationTransitToShamir_Pre14, basePort_TransitToShamir_Pre14, false) } func testSealMigrationTransitToShamir_Pre14( @@ -42,7 +42,11 @@ func testSealMigrationTransitToShamir_Pre14( tss.MakeKey(t, "transit-seal-key") // Initialize the backend with transit. 
- rootToken, recoveryKeys, transitSeal := initializeTransit(t, logger, storage, basePort, tss) + cluster, _, transitSeal := initializeTransit(t, logger, storage, basePort, tss) + rootToken, recoveryKeys := cluster.RootToken, cluster.RecoveryKeys + cluster.EnsureCoresSealed(t) + storage.Cleanup(t, cluster) + cluster.Cleanup() // Migrate the backend from transit to shamir migrateFromTransitToShamir_Pre14(t, logger, storage, basePort, tss, transitSeal, rootToken, recoveryKeys) diff --git a/vault/external_tests/sealmigration/seal_migration_test.go b/vault/external_tests/sealmigration/seal_migration_test.go index 78506b3023ea..68fa55e665e3 100644 --- a/vault/external_tests/sealmigration/seal_migration_test.go +++ b/vault/external_tests/sealmigration/seal_migration_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "fmt" + "sync/atomic" "testing" "time" @@ -12,10 +13,12 @@ import ( "github.com/hashicorp/go-hclog" wrapping "github.com/hashicorp/go-kms-wrapping" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/testhelpers" sealhelper "github.com/hashicorp/vault/helper/testhelpers/seal" "github.com/hashicorp/vault/helper/testhelpers/teststorage" vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/physical/raft" "github.com/hashicorp/vault/sdk/helper/logging" "github.com/hashicorp/vault/vault" ) @@ -24,11 +27,16 @@ const ( numTestCores = 5 keyShares = 3 keyThreshold = 3 + + basePort_ShamirToTransit_Pre14 = 20000 + basePort_TransitToShamir_Pre14 = 21000 + basePort_ShamirToTransit_Post14 = 22000 + basePort_TransitToShamir_Post14 = 23000 ) type testFunc func(t *testing.T, logger hclog.Logger, storage teststorage.ReusableStorage, basePort int) -func testVariousBackends(t *testing.T, tf testFunc, includeRaft bool) { +func testVariousBackends(t *testing.T, tf testFunc, basePort int, includeRaft bool) { logger := logging.NewVaultLogger(hclog.Debug).Named(t.Name()) @@ -39,7 +47,7 @@ func testVariousBackends(t *testing.T, tf testFunc, includeRaft bool) { storage, cleanup := teststorage.MakeReusableStorage( t, logger, teststorage.MakeInmemBackend(t, logger)) defer cleanup() - tf(t, logger, storage, 20000) + tf(t, logger, storage, basePort+100) }) t.Run("file", func(t *testing.T) { @@ -49,7 +57,7 @@ func testVariousBackends(t *testing.T, tf testFunc, includeRaft bool) { storage, cleanup := teststorage.MakeReusableStorage( t, logger, teststorage.MakeFileBackend(t, logger)) defer cleanup() - tf(t, logger, storage, 20020) + tf(t, logger, storage, basePort+200) }) t.Run("consul", func(t *testing.T) { @@ -59,7 +67,7 @@ func testVariousBackends(t *testing.T, tf testFunc, includeRaft bool) { storage, cleanup := teststorage.MakeReusableStorage( t, logger, teststorage.MakeConsulBackend(t, logger)) defer cleanup() - tf(t, logger, storage, 20040) + tf(t, logger, storage, basePort+300) }) if includeRaft { @@ -67,9 +75,14 @@ func testVariousBackends(t *testing.T, tf testFunc, includeRaft bool) { t.Parallel() logger := logger.Named("raft") - storage, cleanup := teststorage.MakeReusableRaftStorage(t, logger, numTestCores) + raftBasePort := basePort + 400 + + atomic.StoreUint32(&vault.UpdateClusterAddrForTests, 1) + addressProvider := testhelpers.NewHardcodedServerAddressProvider(numTestCores, raftBasePort+10) + + storage, cleanup := teststorage.MakeReusableRaftStorage(t, logger, numTestCores, addressProvider) defer cleanup() - tf(t, logger, storage, 20060) + tf(t, logger, storage, raftBasePort) }) } } @@ -80,7 +93,7 @@ func 
testVariousBackends(t *testing.T, tf testFunc, includeRaft bool) { func TestSealMigration_ShamirToTransit_Pre14(t *testing.T) { // Note that we do not test integrated raft storage since this is // a pre-1.4 test. - testVariousBackends(t, testSealMigrationShamirToTransit_Pre14, false) + testVariousBackends(t, testSealMigrationShamirToTransit_Pre14, basePort_ShamirToTransit_Pre14, false) } func testSealMigrationShamirToTransit_Pre14( @@ -88,7 +101,11 @@ func testSealMigrationShamirToTransit_Pre14( storage teststorage.ReusableStorage, basePort int) { // Initialize the backend using shamir - rootToken, barrierKeys := initializeShamir(t, logger, storage, basePort) + cluster, _ := initializeShamir(t, logger, storage, basePort) + rootToken, barrierKeys := cluster.RootToken, cluster.BarrierKeys + cluster.EnsureCoresSealed(t) + storage.Cleanup(t, cluster) + cluster.Cleanup() // Create the transit server. tss := sealhelper.NewTransitSealServer(t) @@ -117,7 +134,8 @@ func migrateFromShamirToTransit_Pre14( var transitSeal vault.Seal var conf = vault.CoreConfig{ - Logger: logger.Named("migrateFromShamirToTransit"), + Logger: logger.Named("migrateFromShamirToTransit"), + DisablePerformanceStandby: true, } var opts = vault.TestClusterOptions{ HandlerFunc: vaulthttp.Handler, @@ -127,9 +145,6 @@ func migrateFromShamirToTransit_Pre14( SkipInit: true, // N.B. Providing a transit seal puts us in migration mode. SealFunc: func() vault.Seal { - // Each core will create its own transit seal here. Later - // on it won't matter which one of these we end up using, since - // they were all created from the same transit key. transitSeal = tss.MakeSeal(t, "transit-seal-key") return transitSeal }, @@ -143,18 +158,16 @@ func migrateFromShamirToTransit_Pre14( }() leader := cluster.Cores[0] - client := leader.Client - client.SetToken(rootToken) + leader.Client.SetToken(rootToken) // Unseal and migrate to Transit. - unsealMigrate(t, client, recoveryKeys, true) + unsealMigrate(t, leader.Client, recoveryKeys, true) - // Wait for migration to finish. Sadly there is no callback, and the - // test will fail later on if we don't do this. - time.Sleep(10 * time.Second) + // Wait for migration to finish. + awaitMigration(t, leader.Client) // Read the secret - secret, err := client.Logical().Read("secret/foo") + secret, err := leader.Client.Logical().Read("secret/foo") if err != nil { t.Fatal(err) } @@ -175,6 +188,254 @@ func migrateFromShamirToTransit_Pre14( return transitSeal } +// TestSealMigration_ShamirToTransit_Post14 tests shamir-to-transit seal +// migration, using the post-1.4 method of bring individual nodes in the cluster +// to do the migration. +func TestSealMigration_ShamirToTransit_Post14(t *testing.T) { + testVariousBackends(t, testSealMigrationShamirToTransit_Post14, basePort_ShamirToTransit_Post14, true) +} + +func testSealMigrationShamirToTransit_Post14( + t *testing.T, logger hclog.Logger, + storage teststorage.ReusableStorage, basePort int) { + + // Initialize the backend using shamir + cluster, opts := initializeShamir(t, logger, storage, basePort) + + // Create the transit server. + tss := sealhelper.NewTransitSealServer(t) + defer func() { + tss.EnsureCoresSealed(t) + tss.Cleanup() + }() + tss.MakeKey(t, "transit-seal-key") + + // Migrate the backend from shamir to transit. + transitSeal := migrateFromShamirToTransit_Post14(t, logger, storage, basePort, tss, cluster, opts) + cluster.EnsureCoresSealed(t) + + storage.Cleanup(t, cluster) + cluster.Cleanup() + + // Run the backend with transit. 
+ runTransit(t, logger, storage, basePort, cluster.RootToken, transitSeal) +} + +func migrateFromShamirToTransit_Post14( + t *testing.T, logger hclog.Logger, + storage teststorage.ReusableStorage, basePort int, + tss *sealhelper.TransitSealServer, + cluster *vault.TestCluster, opts *vault.TestClusterOptions, +) vault.Seal { + + // N.B. Providing a transit seal puts us in migration mode. + var transitSeal vault.Seal + opts.SealFunc = func() vault.Seal { + transitSeal = tss.MakeSeal(t, "transit-seal-key") + return transitSeal + } + modifyCoreConfig := func(tcc *vault.TestClusterCore) {} + + // Restart each follower with the new config, and migrate to Transit. + // Note that the barrier keys are being used as recovery keys. + leaderIdx := migratePost14( + t, logger, storage, cluster, opts, + cluster.RootToken, cluster.BarrierKeys, + migrateShamirToTransit, modifyCoreConfig) + leader := cluster.Cores[leaderIdx] + + // Read the secret + secret, err := leader.Client.Logical().Read("secret/foo") + if err != nil { + t.Fatal(err) + } + if diff := deep.Equal(secret.Data, map[string]interface{}{"zork": "quux"}); len(diff) > 0 { + t.Fatal(diff) + } + + // Make sure the seal configs were updated correctly. + b, r, err := leader.Core.PhysicalSealConfigs(context.Background()) + if err != nil { + t.Fatal(err) + } + verifyBarrierConfig(t, b, wrapping.Transit, 1, 1, 1) + verifyBarrierConfig(t, r, wrapping.Shamir, keyShares, keyThreshold, 0) + + return transitSeal +} + +// TestSealMigration_TransitToShamir_Post14 tests transit-to-shamir seal +// migration, using the post-1.4 method of bring individual nodes in the +// cluster to do the migration. +func TestSealMigration_TransitToShamir_Post14(t *testing.T) { + testVariousBackends(t, testSealMigrationTransitToShamir_Post14, basePort_TransitToShamir_Post14, true) +} + +func testSealMigrationTransitToShamir_Post14( + t *testing.T, logger hclog.Logger, + storage teststorage.ReusableStorage, basePort int) { + + // Create the transit server. + tss := sealhelper.NewTransitSealServer(t) + defer func() { + if tss != nil { + tss.Cleanup() + } + }() + tss.MakeKey(t, "transit-seal-key") + + // Initialize the backend with transit. + cluster, opts, transitSeal := initializeTransit(t, logger, storage, basePort, tss) + rootToken, recoveryKeys := cluster.RootToken, cluster.RecoveryKeys + + // Migrate the backend from transit to shamir + migrateFromTransitToShamir_Post14(t, logger, storage, basePort, tss, transitSeal, cluster, opts) + cluster.EnsureCoresSealed(t) + storage.Cleanup(t, cluster) + cluster.Cleanup() + + // Now that migration is done, we can nuke the transit server, since we + // can unseal without it. + tss.Cleanup() + tss = nil + + // Run the backend with shamir. Note that the recovery keys are now the + // barrier keys. + runShamir(t, logger, storage, basePort, rootToken, recoveryKeys) +} + +func migrateFromTransitToShamir_Post14( + t *testing.T, logger hclog.Logger, + storage teststorage.ReusableStorage, basePort int, + tss *sealhelper.TransitSealServer, transitSeal vault.Seal, + cluster *vault.TestCluster, opts *vault.TestClusterOptions) { + + opts.SealFunc = nil + modifyCoreConfig := func(tcc *vault.TestClusterCore) { + // Nil out the seal so it will be initialized as shamir. + tcc.CoreConfig.Seal = nil + + // N.B. Providing an UnwrapSeal puts us in migration mode. This is the + // equivalent of doing the following in HCL: + // seal "transit" { + // // ... 
+ // disabled = "true" + // } + tcc.CoreConfig.UnwrapSeal = transitSeal + } + + // Restart each follower with the new config, and migrate to Shamir. + leaderIdx := migratePost14( + t, logger, storage, cluster, opts, + cluster.RootToken, cluster.RecoveryKeys, + migrateTransitToShamir, modifyCoreConfig) + leader := cluster.Cores[leaderIdx] + + // Read the secret + secret, err := leader.Client.Logical().Read("secret/foo") + if err != nil { + t.Fatal(err) + } + if diff := deep.Equal(secret.Data, map[string]interface{}{"zork": "quux"}); len(diff) > 0 { + t.Fatal(diff) + } + + // Make sure the seal configs were updated correctly. + b, r, err := cluster.Cores[0].Core.PhysicalSealConfigs(context.Background()) + if err != nil { + t.Fatal(err) + } + verifyBarrierConfig(t, b, wrapping.Shamir, keyShares, keyThreshold, 1) + if r != nil { + t.Fatalf("expected nil recovery config, got: %#v", r) + } +} + +type migrationDirection int + +const ( + migrateShamirToTransit migrationDirection = iota + migrateTransitToShamir +) + +func migratePost14( + t *testing.T, logger hclog.Logger, + storage teststorage.ReusableStorage, + cluster *vault.TestCluster, opts *vault.TestClusterOptions, + rootToken string, recoveryKeys [][]byte, + migrate migrationDirection, + modifyCoreConfig func(*vault.TestClusterCore), +) int { + + // Restart each follower with the new config, and migrate. + for i := 1; i < len(cluster.Cores); i++ { + cluster.StopCore(t, i) + if storage.IsRaft { + teststorage.CloseRaftStorage(t, cluster, i) + } + modifyCoreConfig(cluster.Cores[i]) + cluster.StartCore(t, i, opts) + + cluster.Cores[i].Client.SetToken(rootToken) + unsealMigrate(t, cluster.Cores[i].Client, recoveryKeys, true) + time.Sleep(5 * time.Second) + } + + // Bring down the leader + cluster.StopCore(t, 0) + if storage.IsRaft { + teststorage.CloseRaftStorage(t, cluster, 0) + } + + // Wait for the followers to establish a new leader + leaderIdx, err := testhelpers.AwaitLeader(t, cluster) + if err != nil { + t.Fatal(err) + } + if leaderIdx == 0 { + t.Fatalf("Core 0 cannot be the leader right now") + } + leader := cluster.Cores[leaderIdx] + leader.Client.SetToken(rootToken) + + // Bring core 0 back up + cluster.StartCore(t, 0, opts) + cluster.Cores[0].Client.SetToken(rootToken) + + // TODO look into why this is different for different migration directions, + // and why it is swapped for raft. + switch migrate { + case migrateShamirToTransit: + if storage.IsRaft { + unsealMigrate(t, cluster.Cores[0].Client, recoveryKeys, true) + } else { + unseal(t, cluster.Cores[0].Client, recoveryKeys) + } + case migrateTransitToShamir: + if storage.IsRaft { + unseal(t, cluster.Cores[0].Client, recoveryKeys) + } else { + unsealMigrate(t, cluster.Cores[0].Client, recoveryKeys, true) + } + default: + t.Fatalf("unreachable") + } + + // Wait for migration to finish. + awaitMigration(t, leader.Client) + + // This is apparently necessary for the raft cluster to get itself + // situated. + if storage.IsRaft { + time.Sleep(15 * time.Second) + if err := testhelpers.VerifyRaftConfiguration(leader, len(cluster.Cores)); err != nil { + t.Fatal(err) + } + } + + return leaderIdx +} + func unsealMigrate(t *testing.T, client *api.Client, keys [][]byte, transitServerAvailable bool) { for i, key := range keys { @@ -218,6 +479,53 @@ func unsealMigrate(t *testing.T, client *api.Client, keys [][]byte, transitServe } } +// awaitMigration waits for migration to finish. 
+func awaitMigration(t *testing.T, client *api.Client) { + + timeout := time.Now().Add(60 * time.Second) + for { + if time.Now().After(timeout) { + break + } + + resp, err := client.Sys().SealStatus() + if err != nil { + t.Fatal(err) + } + if !resp.Migration { + return + } + + time.Sleep(time.Second) + } + + t.Fatalf("migration did not complete.") +} + +func unseal(t *testing.T, client *api.Client, keys [][]byte) { + + for i, key := range keys { + + resp, err := client.Sys().UnsealWithOptions(&api.UnsealOpts{ + Key: base64.StdEncoding.EncodeToString(key), + }) + if i < keyThreshold-1 { + // Not enough keys have been provided yet. + if err != nil { + t.Fatal(err) + } + } else { + if err != nil { + t.Fatal(err) + } + if resp == nil || resp.Sealed { + t.Fatalf("expected unsealed state; got %#v", resp) + } + break + } + } +} + // verifyBarrierConfig verifies that a barrier configuration is correct. func verifyBarrierConfig(t *testing.T, cfg *vault.SealConfig, sealType string, shares, threshold, stored int) { t.Helper() @@ -238,13 +546,14 @@ func verifyBarrierConfig(t *testing.T, cfg *vault.SealConfig, sealType string, s // initializeShamir initializes a brand new backend storage with Shamir. func initializeShamir( t *testing.T, logger hclog.Logger, - storage teststorage.ReusableStorage, basePort int) (string, [][]byte) { + storage teststorage.ReusableStorage, basePort int) (*vault.TestCluster, *vault.TestClusterOptions) { var baseClusterPort = basePort + 10 // Start the cluster var conf = vault.CoreConfig{ - Logger: logger.Named("initializeShamir"), + Logger: logger.Named("initializeShamir"), + DisablePerformanceStandby: true, } var opts = vault.TestClusterOptions{ HandlerFunc: vaulthttp.Handler, @@ -255,17 +564,13 @@ func initializeShamir( storage.Setup(&conf, &opts) cluster := vault.NewTestCluster(t, &conf, &opts) cluster.Start() - defer func() { - storage.Cleanup(t, cluster) - cluster.Cleanup() - }() leader := cluster.Cores[0] client := leader.Client // Unseal if storage.IsRaft { - testhelpers.RaftClusterJoinNodes(t, cluster) + joinRaftFollowers(t, cluster, false) if err := testhelpers.VerifyRaftConfiguration(leader, len(cluster.Cores)); err != nil { t.Fatal(err) } @@ -282,10 +587,7 @@ func initializeShamir( t.Fatal(err) } - // Seal the cluster - cluster.EnsureCoresSealed(t) - - return cluster.RootToken, cluster.BarrierKeys + return cluster, &opts } // runShamir uses a pre-populated backend storage with Shamir. @@ -298,7 +600,8 @@ func runShamir( // Start the cluster var conf = vault.CoreConfig{ - Logger: logger.Named("runShamir"), + Logger: logger.Named("runShamir"), + DisablePerformanceStandby: true, } var opts = vault.TestClusterOptions{ HandlerFunc: vaulthttp.Handler, @@ -322,17 +625,12 @@ func runShamir( // Unseal cluster.BarrierKeys = barrierKeys if storage.IsRaft { - provider := testhelpers.NewHardcodedServerAddressProvider(cluster, baseClusterPort) - testhelpers.SetRaftAddressProviders(t, cluster, provider) - for _, core := range cluster.Cores { cluster.UnsealCore(t, core) } - // This is apparently necessary for the raft cluster to get itself // situated. 
time.Sleep(15 * time.Second) - if err := testhelpers.VerifyRaftConfiguration(leader, len(cluster.Cores)); err != nil { t.Fatal(err) } @@ -358,7 +656,7 @@ func runShamir( func initializeTransit( t *testing.T, logger hclog.Logger, storage teststorage.ReusableStorage, basePort int, - tss *sealhelper.TransitSealServer) (string, [][]byte, vault.Seal) { + tss *sealhelper.TransitSealServer) (*vault.TestCluster, *vault.TestClusterOptions, vault.Seal) { var transitSeal vault.Seal @@ -366,7 +664,8 @@ func initializeTransit( // Start the cluster var conf = vault.CoreConfig{ - Logger: logger.Named("initializeTransit"), + Logger: logger.Named("initializeTransit"), + DisablePerformanceStandby: true, } var opts = vault.TestClusterOptions{ HandlerFunc: vaulthttp.Handler, @@ -374,9 +673,6 @@ func initializeTransit( BaseListenAddress: fmt.Sprintf("127.0.0.1:%d", basePort), BaseClusterListenPort: baseClusterPort, SealFunc: func() vault.Seal { - // Each core will create its own transit seal here. Later - // on it won't matter which one of these we end up using, since - // they were all created from the same transit key. transitSeal = tss.MakeSeal(t, "transit-seal-key") return transitSeal }, @@ -384,17 +680,14 @@ func initializeTransit( storage.Setup(&conf, &opts) cluster := vault.NewTestCluster(t, &conf, &opts) cluster.Start() - defer func() { - storage.Cleanup(t, cluster) - cluster.Cleanup() - }() leader := cluster.Cores[0] client := leader.Client // Join raft if storage.IsRaft { - testhelpers.RaftClusterJoinNodesWithStoredKeys(t, cluster) + joinRaftFollowers(t, cluster, true) + if err := testhelpers.VerifyRaftConfiguration(leader, len(cluster.Cores)); err != nil { t.Fatal(err) } @@ -409,10 +702,7 @@ func initializeTransit( t.Fatal(err) } - // Seal the cluster - cluster.EnsureCoresSealed(t) - - return cluster.RootToken, cluster.RecoveryKeys, transitSeal + return cluster, &opts, transitSeal } func runTransit( @@ -424,8 +714,9 @@ func runTransit( // Start the cluster var conf = vault.CoreConfig{ - Logger: logger.Named("runTransit"), - Seal: transitSeal, + Logger: logger.Named("runTransit"), + DisablePerformanceStandby: true, + Seal: transitSeal, } var opts = vault.TestClusterOptions{ HandlerFunc: vaulthttp.Handler, @@ -449,17 +740,12 @@ func runTransit( // Unseal. Even though we are using autounseal, we have to unseal // explicitly because we are using SkipInit. if storage.IsRaft { - provider := testhelpers.NewHardcodedServerAddressProvider(cluster, baseClusterPort) - testhelpers.SetRaftAddressProviders(t, cluster, provider) - for _, core := range cluster.Cores { cluster.UnsealCoreWithStoredKeys(t, core) } - // This is apparently necessary for the raft cluster to get itself // situated. time.Sleep(15 * time.Second) - if err := testhelpers.VerifyRaftConfiguration(leader, len(cluster.Cores)); err != nil { t.Fatal(err) } @@ -483,35 +769,58 @@ func runTransit( cluster.EnsureCoresSealed(t) } -// TestShamir is a temporary test that exercises the reusable raft storage. -// It will be replace once we do the post-1.4 migration testing. -func TestShamir(t *testing.T) { - testVariousBackends(t, testShamir, true) -} +// joinRaftFollowers unseals the leader, and then joins-and-unseals the +// followers one at a time. We assume that the ServerAddressProvider has +// already been installed on all the nodes. 
+func joinRaftFollowers(t *testing.T, cluster *vault.TestCluster, useStoredKeys bool) { -func testShamir( - t *testing.T, logger hclog.Logger, - storage teststorage.ReusableStorage, basePort int) { + leader := cluster.Cores[0] - rootToken, barrierKeys := initializeShamir(t, logger, storage, basePort) - runShamir(t, logger, storage, basePort, rootToken, barrierKeys) -} + cluster.UnsealCore(t, leader) + vault.TestWaitActive(t, leader.Core) -// TestTransit is a temporary test that exercises the reusable raft storage. -// It will be replace once we do the post-1.4 migration testing. -func TestTransit(t *testing.T) { - testVariousBackends(t, testTransit, true) -} + leaderInfos := []*raft.LeaderJoinInfo{ + &raft.LeaderJoinInfo{ + LeaderAPIAddr: leader.Client.Address(), + TLSConfig: leader.TLSConfig, + }, + } -func testTransit( - t *testing.T, logger hclog.Logger, - storage teststorage.ReusableStorage, basePort int) { + // Join followers + for i := 1; i < len(cluster.Cores); i++ { + core := cluster.Cores[i] + _, err := core.JoinRaftCluster(namespace.RootContext(context.Background()), leaderInfos, false) + if err != nil { + t.Fatal(err) + } - // Create the transit server. - tss := sealhelper.NewTransitSealServer(t) - defer tss.Cleanup() - tss.MakeKey(t, "transit-seal-key") + if useStoredKeys { + // For autounseal, the raft backend is not initialized right away + // after the join. We need to wait briefly before we can unseal. + awaitUnsealWithStoredKeys(t, core) + } else { + cluster.UnsealCore(t, core) + } + } - rootToken, _, transitSeal := initializeTransit(t, logger, storage, basePort, tss) - runTransit(t, logger, storage, basePort, rootToken, transitSeal) + testhelpers.WaitForNCoresUnsealed(t, cluster, len(cluster.Cores)) +} + +func awaitUnsealWithStoredKeys(t *testing.T, core *vault.TestClusterCore) { + + timeout := time.Now().Add(30 * time.Second) + for { + if time.Now().After(timeout) { + t.Fatal("raft join: timeout waiting for core to unseal") + } + // Its actually ok for an error to happen here the first couple of + // times -- it means the raft join hasn't gotten around to initializing + // the backend yet. 
+ err := core.UnsealWithStoredKeys(context.Background()) + if err == nil { + return + } + core.Logger().Warn("raft join: failed to unseal core", "error", err) + time.Sleep(time.Second) + } } diff --git a/vault/testing.go b/vault/testing.go index 0c73b1f29ae5..6f11fceea92b 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -27,20 +27,18 @@ import ( "time" "github.com/armon/go-metrics" - hclog "github.com/hashicorp/go-hclog" - log "github.com/hashicorp/go-hclog" - "github.com/hashicorp/vault/helper/metricsutil" - "github.com/hashicorp/vault/vault/cluster" - "github.com/hashicorp/vault/vault/seal" "github.com/mitchellh/copystructure" - + testing "github.com/mitchellh/go-testing-interface" "golang.org/x/crypto/ed25519" "golang.org/x/net/http2" cleanhttp "github.com/hashicorp/go-cleanhttp" + log "github.com/hashicorp/go-hclog" + raftlib "github.com/hashicorp/raft" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/command/server" + "github.com/hashicorp/vault/helper/metricsutil" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/internalshared/reloadutil" @@ -52,9 +50,9 @@ import ( "github.com/hashicorp/vault/sdk/helper/salt" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/physical" - testing "github.com/mitchellh/go-testing-interface" - physInmem "github.com/hashicorp/vault/sdk/physical/inmem" + "github.com/hashicorp/vault/vault/cluster" + "github.com/hashicorp/vault/vault/seal" ) // This file contains a number of methods that are useful for unit @@ -724,6 +722,11 @@ type TestCluster struct { Logger log.Logger CleanupFunc func() SetupFunc func() + + cleanupFuncs []func() + base *CoreConfig + pubKey interface{} + priKey interface{} } func (c *TestCluster) Start() { @@ -836,6 +839,39 @@ func (c *TestClusterCore) Seal(t testing.T) { } } +func (c *TestClusterCore) stop() error { + + c.Logger().Info("stopping vault test core") + + if c.Listeners != nil { + for _, ln := range c.Listeners { + ln.Close() + } + c.Logger().Info("listeners successfully shut down") + } + if c.licensingStopCh != nil { + close(c.licensingStopCh) + c.licensingStopCh = nil + } + + if err := c.Shutdown(); err != nil { + return err + } + timeout := time.Now().Add(60 * time.Second) + for { + if time.Now().After(timeout) { + return errors.New("timeout waiting for core to seal") + } + if c.Sealed() { + break + } + time.Sleep(250 * time.Millisecond) + } + + c.Logger().Info("vault test core stopped") + return nil +} + func CleanupClusters(clusters []*TestCluster) { wg := &sync.WaitGroup{} for _, cluster := range clusters { @@ -855,7 +891,6 @@ func (c *TestCluster) Cleanup() { core.CoreConfig.Logger.SetLevel(log.Error) } - // Close listeners wg := &sync.WaitGroup{} for _, core := range c.Cores { wg.Add(1) @@ -863,29 +898,8 @@ func (c *TestCluster) Cleanup() { go func() { defer wg.Done() - if lc.Listeners != nil { - for _, ln := range lc.Listeners { - ln.Close() - } - } - if lc.licensingStopCh != nil { - close(lc.licensingStopCh) - lc.licensingStopCh = nil - } - - if err := lc.Shutdown(); err != nil { - lc.Logger().Error("error during shutdown; abandoning sealing", "error", err) - } else { - timeout := time.Now().Add(60 * time.Second) - for { - if time.Now().After(timeout) { - lc.Logger().Error("timeout waiting for core to seal") - } - if lc.Sealed() { - break - } - time.Sleep(250 * time.Millisecond) - } + if err := lc.stop(); err != nil { + lc.Logger().Error("error during cleanup", 
"error", err) } }() } @@ -937,6 +951,7 @@ type TestClusterCore struct { CoreConfig *CoreConfig Client *api.Client Handler http.Handler + Address *net.TCPAddr Listeners []*TestListener ReloadFuncs *map[string][]reloadutil.ReloadFunc ReloadFuncsLock *sync.RWMutex @@ -1000,7 +1015,7 @@ type TestClusterOptions struct { // core in cluster will have 0, second 1, etc. // If the backend is shared across the cluster (i.e. is not Raft) then it // should return nil when coreIdx != 0. - PhysicalFactory func(t testing.T, coreIdx int, logger hclog.Logger) *PhysicalBackendBundle + PhysicalFactory func(t testing.T, coreIdx int, logger log.Logger) *PhysicalBackendBundle // FirstCoreNumber is used to assign a unique number to each core within // a multi-cluster setup. FirstCoreNumber int @@ -1011,6 +1026,14 @@ type TestClusterOptions struct { // ClusterLayers are used to override the default cluster connection layer ClusterLayers cluster.NetworkLayerSet + + // RaftAddressProvider is used to set the raft ServerAddressProvider on + // each core. + // + // If SkipInit is true, then RaftAddressProvider has no effect. + // RaftAddressProvider should only be specified if the underlying physical + // storage is Raft. + RaftAddressProvider raftlib.ServerAddressProvider } var DefaultNumCores = 3 @@ -1024,7 +1047,7 @@ type certInfo struct { } type TestLogger struct { - hclog.Logger + log.Logger Path string File *os.File } @@ -1078,16 +1101,6 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te numCores = opts.NumCores } - var disablePR1103 bool - if opts != nil && opts.PR1103Disabled { - disablePR1103 = true - } - - var firstCoreNumber int - if opts != nil { - firstCoreNumber = opts.FirstCoreNumber - } - certIPs := []net.IP{ net.IPv6loopback, net.ParseIP("127.0.0.1"), @@ -1099,17 +1112,15 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te t.Fatal("could not parse given base IP") } certIPs = append(certIPs, baseAddr.IP) - } - - baseClusterListenPort := 0 - if opts != nil && opts.BaseClusterListenPort != 0 { - if opts.BaseListenAddress == "" { - t.Fatal("BaseListenAddress is not specified") + } else { + baseAddr = &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, } - baseClusterListenPort = opts.BaseClusterListenPort } var testCluster TestCluster + testCluster.base = base switch { case opts != nil && opts.Logger != nil: @@ -1259,29 +1270,28 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te // // Listener setup // - ports := make([]int, numCores) - if baseAddr != nil { - for i := 0; i < numCores; i++ { - ports[i] = baseAddr.Port + i - } - } else { - baseAddr = &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 0, - } - } - + addresses := []*net.TCPAddr{} listeners := [][]*TestListener{} servers := []*http.Server{} handlers := []http.Handler{} tlsConfigs := []*tls.Config{} certGetters := []*reloadutil.CertificateGetter{} for i := 0; i < numCores; i++ { - baseAddr.Port = ports[i] - ln, err := net.ListenTCP("tcp", baseAddr) + + addr := &net.TCPAddr{ + IP: baseAddr.IP, + Port: 0, + } + if baseAddr.Port != 0 { + addr.Port = baseAddr.Port + i + } + + ln, err := net.ListenTCP("tcp", addr) if err != nil { t.Fatal(err) } + addresses = append(addresses, addr) + certFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_cert.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) keyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_key.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) err = 
ioutil.WriteFile(certFile, certInfoSlice[i].certPEM, 0755) @@ -1450,94 +1460,289 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te if err != nil { t.Fatalf("err: %v", err) } + testCluster.pubKey = pubKey + testCluster.priKey = priKey - cleanupFuncs := []func(){} + // Create cores + testCluster.cleanupFuncs = []func(){} cores := []*Core{} coreConfigs := []*CoreConfig{} + for i := 0; i < numCores; i++ { - localConfig := *coreConfig - localConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", listeners[i][0].Address.Port) - - // if opts.SealFunc is provided, use that to generate a seal for the config instead - if opts != nil && opts.SealFunc != nil { - localConfig.Seal = opts.SealFunc() - } - - if coreConfig.Logger == nil || (opts != nil && opts.Logger != nil) { - localConfig.Logger = testCluster.Logger.Named(fmt.Sprintf("core%d", i)) - } - if opts != nil && opts.PhysicalFactory != nil { - physBundle := opts.PhysicalFactory(t, i, localConfig.Logger) - switch { - case physBundle == nil && coreConfig.Physical != nil: - case physBundle == nil && coreConfig.Physical == nil: - t.Fatal("PhysicalFactory produced no physical and none in CoreConfig") - case physBundle != nil: - testCluster.Logger.Info("created physical backend", "instance", i) - coreConfig.Physical = physBundle.Backend - localConfig.Physical = physBundle.Backend - base.Physical = physBundle.Backend - haBackend := physBundle.HABackend - if haBackend == nil { - if ha, ok := physBundle.Backend.(physical.HABackend); ok { - haBackend = ha - } - } - coreConfig.HAPhysical = haBackend - localConfig.HAPhysical = haBackend - if physBundle.Cleanup != nil { - cleanupFuncs = append(cleanupFuncs, physBundle.Cleanup) - } - } + cleanup, c, localConfig, handler := testCluster.newCore(t, i, coreConfig, opts, listeners[i], pubKey) + + testCluster.cleanupFuncs = append(testCluster.cleanupFuncs, cleanup) + cores = append(cores, c) + coreConfigs = append(coreConfigs, &localConfig) + + if handler != nil { + handlers[i] = handler + servers[i].Handler = handlers[i] } + } + + // Clustering setup + for i := 0; i < numCores; i++ { + testCluster.setupClusterListener(t, i, cores[i], coreConfigs[i], opts, listeners[i], handlers[i]) + } + + // Create TestClusterCores + var ret []*TestClusterCore + for i := 0; i < numCores; i++ { - if opts != nil && opts.ClusterLayers != nil { - localConfig.ClusterNetworkLayer = opts.ClusterLayers.Layers()[i] + tcc := &TestClusterCore{ + Core: cores[i], + CoreConfig: coreConfigs[i], + ServerKey: certInfoSlice[i].key, + ServerKeyPEM: certInfoSlice[i].keyPEM, + ServerCert: certInfoSlice[i].cert, + ServerCertBytes: certInfoSlice[i].certBytes, + ServerCertPEM: certInfoSlice[i].certPEM, + Address: addresses[i], + Listeners: listeners[i], + Handler: handlers[i], + Server: servers[i], + TLSConfig: tlsConfigs[i], + Barrier: cores[i].barrier, + NodeID: fmt.Sprintf("core-%d", i), + UnderlyingRawStorage: coreConfigs[i].Physical, } + tcc.ReloadFuncs = &cores[i].reloadFuncs + tcc.ReloadFuncsLock = &cores[i].reloadFuncsLock + tcc.ReloadFuncsLock.Lock() + (*tcc.ReloadFuncs)["listener|tcp"] = []reloadutil.ReloadFunc{certGetters[i].Reload} + tcc.ReloadFuncsLock.Unlock() - switch { - case localConfig.LicensingConfig != nil: - if pubKey != nil { - localConfig.LicensingConfig.AdditionalPublicKeys = append(localConfig.LicensingConfig.AdditionalPublicKeys, pubKey.(ed25519.PublicKey)) + testAdjustTestCore(base, tcc) + + ret = append(ret, tcc) + } + testCluster.Cores = ret + + // Initialize cores + if opts == nil || 
!opts.SkipInit { + testCluster.initCores(t, opts, addAuditBackend) + } + + // Assign clients + for i := 0; i < numCores; i++ { + testCluster.Cores[i].Client = + testCluster.getAPIClient(t, opts, listeners[i][0].Address.Port, tlsConfigs[i]) + } + + // Extra Setup + for _, tcc := range testCluster.Cores { + testExtraTestCoreSetup(t, priKey, tcc) + } + + // Cleanup + testCluster.CleanupFunc = func() { + for _, c := range testCluster.cleanupFuncs { + c() + } + if l, ok := testCluster.Logger.(*TestLogger); ok { + if t.Failed() { + _ = l.File.Close() + } else { + _ = os.Remove(l.Path) } - default: - localConfig.LicensingConfig = testGetLicensingConfig(pubKey) } + } - if localConfig.MetricsHelper == nil { - inm := metrics.NewInmemSink(10*time.Second, time.Minute) - metrics.DefaultInmemSignal(inm) - localConfig.MetricsHelper = metricsutil.NewMetricsHelper(inm, false) + // Setup + if opts != nil { + if opts.SetupFunc != nil { + testCluster.SetupFunc = func() { + opts.SetupFunc(t, &testCluster) + } } + } - c, err := NewCore(&localConfig) - if err != nil { - t.Fatalf("err: %v", err) - } - c.coreNumber = firstCoreNumber + i - c.PR1103disabled = disablePR1103 - cores = append(cores, c) - coreConfigs = append(coreConfigs, &localConfig) - if opts != nil && opts.HandlerFunc != nil { - props := opts.DefaultHandlerProperties - props.Core = c - if props.ListenerConfig != nil && props.ListenerConfig.MaxRequestDuration == 0 { - props.ListenerConfig.MaxRequestDuration = DefaultMaxRequestDuration + return &testCluster +} + +// StopCore performs an orderly shutdown of a core. +func (cluster *TestCluster) StopCore(t testing.T, idx int) { + t.Helper() + + if idx < 0 || idx > len(cluster.Cores) { + t.Fatalf("invalid core index %d", idx) + } + tcc := cluster.Cores[idx] + tcc.Logger().Info("stopping core", "core", idx) + + // Stop listeners and call Shutdown() + if err := tcc.stop(); err != nil { + t.Fatal(err) + } + + // Run cleanup + cluster.cleanupFuncs[idx]() +} + +// Restart a TestClusterCore that was stopped, by replacing the +// underlying Core. 
+func (cluster *TestCluster) StartCore(t testing.T, idx int, opts *TestClusterOptions) { + t.Helper() + + if idx < 0 || idx > len(cluster.Cores) { + t.Fatalf("invalid core index %d", idx) + } + tcc := cluster.Cores[idx] + tcc.Logger().Info("restarting core", "core", idx) + + // Set up listeners + ln, err := net.ListenTCP("tcp", tcc.Address) + if err != nil { + t.Fatal(err) + } + tcc.Listeners = []*TestListener{&TestListener{ + Listener: tls.NewListener(ln, tcc.TLSConfig), + Address: ln.Addr().(*net.TCPAddr), + }, + } + + tcc.Handler = http.NewServeMux() + tcc.Server = &http.Server{ + Handler: tcc.Handler, + ErrorLog: cluster.Logger.StandardLogger(nil), + } + + // Create a new Core + cleanup, newCore, localConfig, coreHandler := cluster.newCore( + t, idx, tcc.CoreConfig, opts, tcc.Listeners, cluster.pubKey) + if coreHandler != nil { + tcc.Handler = coreHandler + tcc.Server.Handler = coreHandler + } + + cluster.cleanupFuncs[idx] = cleanup + tcc.Core = newCore + tcc.CoreConfig = &localConfig + tcc.UnderlyingRawStorage = localConfig.Physical + + cluster.setupClusterListener( + t, idx, newCore, tcc.CoreConfig, + opts, tcc.Listeners, tcc.Handler) + + tcc.Client = cluster.getAPIClient(t, opts, tcc.Listeners[0].Address.Port, tcc.TLSConfig) + + testAdjustTestCore(cluster.base, tcc) + testExtraTestCoreSetup(t, cluster.priKey, tcc) + + // Start listeners + for _, ln := range tcc.Listeners { + tcc.Logger().Info("starting listener for core", "port", ln.Address.Port) + go tcc.Server.Serve(ln) + } + + tcc.Logger().Info("restarted test core", "core", idx) +} + +func (testCluster *TestCluster) newCore( + t testing.T, idx int, coreConfig *CoreConfig, + opts *TestClusterOptions, listeners []*TestListener, pubKey interface{}, +) (func(), *Core, CoreConfig, http.Handler) { + + localConfig := *coreConfig + cleanupFunc := func() {} + var handler http.Handler + + var disablePR1103 bool + if opts != nil && opts.PR1103Disabled { + disablePR1103 = true + } + + var firstCoreNumber int + if opts != nil { + firstCoreNumber = opts.FirstCoreNumber + } + + localConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", listeners[0].Address.Port) + + // if opts.SealFunc is provided, use that to generate a seal for the config instead + if opts != nil && opts.SealFunc != nil { + localConfig.Seal = opts.SealFunc() + } + + if coreConfig.Logger == nil || (opts != nil && opts.Logger != nil) { + localConfig.Logger = testCluster.Logger.Named(fmt.Sprintf("core%d", idx)) + } + if opts != nil && opts.PhysicalFactory != nil { + physBundle := opts.PhysicalFactory(t, idx, localConfig.Logger) + switch { + case physBundle == nil && coreConfig.Physical != nil: + case physBundle == nil && coreConfig.Physical == nil: + t.Fatal("PhysicalFactory produced no physical and none in CoreConfig") + case physBundle != nil: + testCluster.Logger.Info("created physical backend", "instance", idx) + coreConfig.Physical = physBundle.Backend + localConfig.Physical = physBundle.Backend + haBackend := physBundle.HABackend + if haBackend == nil { + if ha, ok := physBundle.Backend.(physical.HABackend); ok { + haBackend = ha + } + } + coreConfig.HAPhysical = haBackend + localConfig.HAPhysical = haBackend + if physBundle.Cleanup != nil { + cleanupFunc = physBundle.Cleanup } - handlers[i] = opts.HandlerFunc(&props) - servers[i].Handler = handlers[i] } + } - // Set this in case the Seal was manually set before the core was - // created - if localConfig.Seal != nil { - localConfig.Seal.SetCore(c) + if opts != nil && opts.ClusterLayers != nil { + 
localConfig.ClusterNetworkLayer = opts.ClusterLayers.Layers()[idx] + } + + switch { + case localConfig.LicensingConfig != nil: + if pubKey != nil { + localConfig.LicensingConfig.AdditionalPublicKeys = append(localConfig.LicensingConfig.AdditionalPublicKeys, pubKey.(ed25519.PublicKey)) } + default: + localConfig.LicensingConfig = testGetLicensingConfig(pubKey) + } + + if localConfig.MetricsHelper == nil { + inm := metrics.NewInmemSink(10*time.Second, time.Minute) + metrics.DefaultInmemSignal(inm) + localConfig.MetricsHelper = metricsutil.NewMetricsHelper(inm, false) + } + + c, err := NewCore(&localConfig) + if err != nil { + t.Fatalf("err: %v", err) + } + c.coreNumber = firstCoreNumber + idx + c.PR1103disabled = disablePR1103 + if opts != nil && opts.HandlerFunc != nil { + props := opts.DefaultHandlerProperties + props.Core = c + if props.ListenerConfig != nil && props.ListenerConfig.MaxRequestDuration == 0 { + props.ListenerConfig.MaxRequestDuration = DefaultMaxRequestDuration + } + handler = opts.HandlerFunc(&props) + } + + // Set this in case the Seal was manually set before the core was + // created + if localConfig.Seal != nil { + localConfig.Seal.SetCore(c) + } + + return cleanupFunc, c, localConfig, handler +} + +func (testCluster *TestCluster) setupClusterListener( + t testing.T, idx int, core *Core, coreConfig *CoreConfig, + opts *TestClusterOptions, listeners []*TestListener, handler http.Handler) { + + if coreConfig.ClusterAddr == "" { + return } - // - // Clustering setup - // clusterAddrGen := func(lns []*TestListener, port int) []*net.TCPAddr { ret := make([]*net.TCPAddr, len(lns)) for i, ln := range lns { @@ -1549,251 +1754,207 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te return ret } - for i := 0; i < numCores; i++ { - if coreConfigs[i].ClusterAddr != "" { - port := 0 - if baseClusterListenPort != 0 { - port = baseClusterListenPort + i - } - cores[i].Logger().Info("assigning cluster listener for test core", "core", i, "port", port) - cores[i].SetClusterListenerAddrs(clusterAddrGen(listeners[i], port)) - cores[i].SetClusterHandler(handlers[i]) + baseClusterListenPort := 0 + if opts != nil && opts.BaseClusterListenPort != 0 { + if opts.BaseListenAddress == "" { + t.Fatal("BaseListenAddress is not specified") } + baseClusterListenPort = opts.BaseClusterListenPort } - if opts == nil || !opts.SkipInit { - bKeys, rKeys, root := TestCoreInitClusterWrapperSetup(t, cores[0], handlers[0]) - barrierKeys, _ := copystructure.Copy(bKeys) - testCluster.BarrierKeys = barrierKeys.([][]byte) - recoveryKeys, _ := copystructure.Copy(rKeys) - testCluster.RecoveryKeys = recoveryKeys.([][]byte) - testCluster.RootToken = root - - // Write root token and barrier keys - err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "root_token"), []byte(root), 0755) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - for i, key := range testCluster.BarrierKeys { - buf.Write([]byte(base64.StdEncoding.EncodeToString(key))) - if i < len(testCluster.BarrierKeys)-1 { - buf.WriteRune('\n') - } - } - err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "barrier_keys"), buf.Bytes(), 0755) - if err != nil { - t.Fatal(err) - } - for i, key := range testCluster.RecoveryKeys { - buf.Write([]byte(base64.StdEncoding.EncodeToString(key))) - if i < len(testCluster.RecoveryKeys)-1 { - buf.WriteRune('\n') - } + port := 0 + if baseClusterListenPort != 0 { + port = baseClusterListenPort + idx + } + core.Logger().Info("assigning cluster listener for test core", "core", 
idx, "port", port) + core.SetClusterListenerAddrs(clusterAddrGen(listeners, port)) + core.SetClusterHandler(handler) +} + +func (tc *TestCluster) initCores(t testing.T, opts *TestClusterOptions, addAuditBackend bool) { + + leader := tc.Cores[0] + + bKeys, rKeys, root := TestCoreInitClusterWrapperSetup(t, leader.Core, leader.Handler) + barrierKeys, _ := copystructure.Copy(bKeys) + tc.BarrierKeys = barrierKeys.([][]byte) + recoveryKeys, _ := copystructure.Copy(rKeys) + tc.RecoveryKeys = recoveryKeys.([][]byte) + tc.RootToken = root + + // Write root token and barrier keys + err := ioutil.WriteFile(filepath.Join(tc.TempDir, "root_token"), []byte(root), 0755) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + for i, key := range tc.BarrierKeys { + buf.Write([]byte(base64.StdEncoding.EncodeToString(key))) + if i < len(tc.BarrierKeys)-1 { + buf.WriteRune('\n') } - err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "recovery_keys"), buf.Bytes(), 0755) - if err != nil { - t.Fatal(err) + } + err = ioutil.WriteFile(filepath.Join(tc.TempDir, "barrier_keys"), buf.Bytes(), 0755) + if err != nil { + t.Fatal(err) + } + for i, key := range tc.RecoveryKeys { + buf.Write([]byte(base64.StdEncoding.EncodeToString(key))) + if i < len(tc.RecoveryKeys)-1 { + buf.WriteRune('\n') } + } + err = ioutil.WriteFile(filepath.Join(tc.TempDir, "recovery_keys"), buf.Bytes(), 0755) + if err != nil { + t.Fatal(err) + } - // Unseal first core - for _, key := range bKeys { - if _, err := cores[0].Unseal(TestKeyCopy(key)); err != nil { - t.Fatalf("unseal err: %s", err) - } + // Unseal first core + for _, key := range bKeys { + if _, err := leader.Core.Unseal(TestKeyCopy(key)); err != nil { + t.Fatalf("unseal err: %s", err) } + } - ctx := context.Background() + ctx := context.Background() - // If stored keys is supported, the above will no no-op, so trigger auto-unseal - // using stored keys to try to unseal - if err := cores[0].UnsealWithStoredKeys(ctx); err != nil { - t.Fatal(err) - } + // If stored keys is supported, the above will no no-op, so trigger auto-unseal + // using stored keys to try to unseal + if err := leader.Core.UnsealWithStoredKeys(ctx); err != nil { + t.Fatal(err) + } - // Verify unsealed - if cores[0].Sealed() { - t.Fatal("should not be sealed") - } + // Verify unsealed + if leader.Core.Sealed() { + t.Fatal("should not be sealed") + } - TestWaitActive(t, cores[0]) + TestWaitActive(t, leader.Core) - // Existing tests rely on this; we can make a toggle to disable it - // later if we want - kvReq := &logical.Request{ - Operation: logical.UpdateOperation, - ClientToken: testCluster.RootToken, - Path: "sys/mounts/secret", - Data: map[string]interface{}{ - "type": "kv", - "path": "secret/", - "description": "key/value secret storage", - "options": map[string]string{ - "version": "1", - }, + // Existing tests rely on this; we can make a toggle to disable it + // later if we want + kvReq := &logical.Request{ + Operation: logical.UpdateOperation, + ClientToken: tc.RootToken, + Path: "sys/mounts/secret", + Data: map[string]interface{}{ + "type": "kv", + "path": "secret/", + "description": "key/value secret storage", + "options": map[string]string{ + "version": "1", }, - } - resp, err := cores[0].HandleRequest(namespace.RootContext(ctx), kvReq) - if err != nil { - t.Fatal(err) - } - if resp.IsError() { - t.Fatal(err) - } - - cfg, err := cores[0].seal.BarrierConfig(ctx) - if err != nil { - t.Fatal(err) - } + }, + } + resp, err := leader.Core.HandleRequest(namespace.RootContext(ctx), kvReq) + if err != 
nil { + t.Fatal(err) + } + if resp.IsError() { + t.Fatal(err) + } - // Unseal other cores unless otherwise specified - if (opts == nil || !opts.KeepStandbysSealed) && numCores > 1 { - for i := 1; i < numCores; i++ { - cores[i].seal.SetCachedBarrierConfig(cfg) - for _, key := range bKeys { - if _, err := cores[i].Unseal(TestKeyCopy(key)); err != nil { - t.Fatalf("unseal err: %s", err) - } - } + cfg, err := leader.Core.seal.BarrierConfig(ctx) + if err != nil { + t.Fatal(err) + } - // If stored keys is supported, the above will no no-op, so trigger auto-unseal - // using stored keys - if err := cores[i].UnsealWithStoredKeys(ctx); err != nil { - t.Fatal(err) + // Unseal other cores unless otherwise specified + numCores := len(tc.Cores) + if (opts == nil || !opts.KeepStandbysSealed) && numCores > 1 { + for i := 1; i < numCores; i++ { + tc.Cores[i].Core.seal.SetCachedBarrierConfig(cfg) + for _, key := range bKeys { + if _, err := tc.Cores[i].Core.Unseal(TestKeyCopy(key)); err != nil { + t.Fatalf("unseal err: %s", err) } } - // Let them come fully up to standby - time.Sleep(2 * time.Second) - - // Ensure cluster connection info is populated. - // Other cores should not come up as leaders. - for i := 1; i < numCores; i++ { - isLeader, _, _, err := cores[i].Leader() - if err != nil { - t.Fatal(err) - } - if isLeader { - t.Fatalf("core[%d] should not be leader", i) - } + // If stored keys is supported, the above will no no-op, so trigger auto-unseal + // using stored keys + if err := tc.Cores[i].Core.UnsealWithStoredKeys(ctx); err != nil { + t.Fatal(err) } } - // - // Set test cluster core(s) and test cluster - // - cluster, err := cores[0].Cluster(context.Background()) - if err != nil { - t.Fatal(err) - } - testCluster.ID = cluster.ID - - if addAuditBackend { - // Enable auditing. - auditReq := &logical.Request{ - Operation: logical.UpdateOperation, - ClientToken: testCluster.RootToken, - Path: "sys/audit/noop", - Data: map[string]interface{}{ - "type": "noop", - }, - } - resp, err = cores[0].HandleRequest(namespace.RootContext(ctx), auditReq) + // Let them come fully up to standby + time.Sleep(2 * time.Second) + + // Ensure cluster connection info is populated. + // Other cores should not come up as leaders. + for i := 1; i < numCores; i++ { + isLeader, _, _, err := tc.Cores[i].Core.Leader() if err != nil { t.Fatal(err) } - - if resp.IsError() { - t.Fatal(err) + if isLeader { + t.Fatalf("core[%d] should not be leader", i) } } } - getAPIClient := func(port int, tlsConfig *tls.Config) *api.Client { - transport := cleanhttp.DefaultPooledTransport() - transport.TLSClientConfig = tlsConfig.Clone() - if err := http2.ConfigureTransport(transport); err != nil { - t.Fatal(err) - } - client := &http.Client{ - Transport: transport, - CheckRedirect: func(*http.Request, []*http.Request) error { - // This can of course be overridden per-test by using its own client - return fmt.Errorf("redirects not allowed in these tests") + // + // Set test cluster core(s) and test cluster + // + cluster, err := leader.Core.Cluster(context.Background()) + if err != nil { + t.Fatal(err) + } + tc.ID = cluster.ID + + if addAuditBackend { + // Enable auditing. 
+ auditReq := &logical.Request{ + Operation: logical.UpdateOperation, + ClientToken: tc.RootToken, + Path: "sys/audit/noop", + Data: map[string]interface{}{ + "type": "noop", }, } - config := api.DefaultConfig() - if config.Error != nil { - t.Fatal(config.Error) - } - config.Address = fmt.Sprintf("https://127.0.0.1:%d", port) - config.HttpClient = client - config.MaxRetries = 0 - apiClient, err := api.NewClient(config) + resp, err = leader.Core.HandleRequest(namespace.RootContext(ctx), auditReq) if err != nil { t.Fatal(err) } - if opts == nil || !opts.SkipInit { - apiClient.SetToken(testCluster.RootToken) - } - return apiClient - } - var ret []*TestClusterCore - for i := 0; i < numCores; i++ { - tcc := &TestClusterCore{ - Core: cores[i], - CoreConfig: coreConfigs[i], - ServerKey: certInfoSlice[i].key, - ServerKeyPEM: certInfoSlice[i].keyPEM, - ServerCert: certInfoSlice[i].cert, - ServerCertBytes: certInfoSlice[i].certBytes, - ServerCertPEM: certInfoSlice[i].certPEM, - Listeners: listeners[i], - Handler: handlers[i], - Server: servers[i], - TLSConfig: tlsConfigs[i], - Client: getAPIClient(listeners[i][0].Address.Port, tlsConfigs[i]), - Barrier: cores[i].barrier, - NodeID: fmt.Sprintf("core-%d", i), - UnderlyingRawStorage: coreConfigs[i].Physical, + if resp.IsError() { + t.Fatal(err) } - tcc.ReloadFuncs = &cores[i].reloadFuncs - tcc.ReloadFuncsLock = &cores[i].reloadFuncsLock - tcc.ReloadFuncsLock.Lock() - (*tcc.ReloadFuncs)["listener|tcp"] = []reloadutil.ReloadFunc{certGetters[i].Reload} - tcc.ReloadFuncsLock.Unlock() - - testAdjustTestCore(base, tcc) - - ret = append(ret, tcc) } - testCluster.Cores = ret +} - testExtraClusterCoresTestSetup(t, priKey, testCluster.Cores) +func (testCluster *TestCluster) getAPIClient( + t testing.T, opts *TestClusterOptions, + port int, tlsConfig *tls.Config) *api.Client { - testCluster.CleanupFunc = func() { - for _, c := range cleanupFuncs { - c() - } - if l, ok := testCluster.Logger.(*TestLogger); ok { - if t.Failed() { - _ = l.File.Close() - } else { - _ = os.Remove(l.Path) - } - } + transport := cleanhttp.DefaultPooledTransport() + transport.TLSClientConfig = tlsConfig.Clone() + if err := http2.ConfigureTransport(transport); err != nil { + t.Fatal(err) } - if opts != nil { - if opts.SetupFunc != nil { - testCluster.SetupFunc = func() { - opts.SetupFunc(t, &testCluster) - } - } + client := &http.Client{ + Transport: transport, + CheckRedirect: func(*http.Request, []*http.Request) error { + // This can of course be overridden per-test by using its own client + return fmt.Errorf("redirects not allowed in these tests") + }, } - - return &testCluster + config := api.DefaultConfig() + if config.Error != nil { + t.Fatal(config.Error) + } + config.Address = fmt.Sprintf("https://127.0.0.1:%d", port) + config.HttpClient = client + config.MaxRetries = 0 + apiClient, err := api.NewClient(config) + if err != nil { + t.Fatal(err) + } + if opts == nil || !opts.SkipInit { + apiClient.SetToken(testCluster.RootToken) + } + return apiClient } func NewMockBuiltinRegistry() *mockBuiltinRegistry { diff --git a/vault/testing_util.go b/vault/testing_util.go index 26c7cde057cc..0d1887298f74 100644 --- a/vault/testing_util.go +++ b/vault/testing_util.go @@ -6,9 +6,9 @@ import ( testing "github.com/mitchellh/go-testing-interface" ) -func testGenerateCoreKeys() (interface{}, interface{}, error) { return nil, nil, nil } -func testGetLicensingConfig(interface{}) *LicensingConfig { return &LicensingConfig{} } -func testExtraClusterCoresTestSetup(testing.T, interface{}, 
[]*TestClusterCore) {} +func testGenerateCoreKeys() (interface{}, interface{}, error) { return nil, nil, nil } +func testGetLicensingConfig(interface{}) *LicensingConfig { return &LicensingConfig{} } +func testExtraTestCoreSetup(testing.T, interface{}, *TestClusterCore) {} func testAdjustTestCore(_ *CoreConfig, tcc *TestClusterCore) { tcc.UnderlyingStorage = tcc.physical }
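
The sketch below is illustrative only and is not part of the patch: it shows how a test might drive the kv mount that initCores creates at "secret/" through the per-core API client wired up by getAPIClient. The test name TestKVSmoke and the use of vaulthttp.Handler as the cluster handler are assumptions for the example, not something this diff adds.

// Illustrative sketch, not part of this patch. Assumes the usual test-cluster
// wiring (vaulthttp.Handler for TestClusterOptions.HandlerFunc); TestKVSmoke
// is a made-up name for this example.
package vault_test

import (
	"testing"

	vaulthttp "github.com/hashicorp/vault/http"
	"github.com/hashicorp/vault/vault"
)

func TestKVSmoke(t *testing.T) {
	cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
		HandlerFunc: vaulthttp.Handler,
	})
	cluster.Start()
	defer cluster.Cleanup()

	// initCores mounts a v1 kv backend at secret/, and getAPIClient attaches
	// the root token, so the leader's client can write and read immediately.
	client := cluster.Cores[0].Client
	if _, err := client.Logical().Write("secret/smoke", map[string]interface{}{
		"ok": "yes",
	}); err != nil {
		t.Fatal(err)
	}

	secret, err := client.Logical().Read("secret/smoke")
	if err != nil {
		t.Fatal(err)
	}
	if secret == nil || secret.Data["ok"] != "yes" {
		t.Fatalf("unexpected secret: %#v", secret)
	}
}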