From c9de89d87be03d326c160cb6cfd470c649ae1272 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 16 Jan 2020 13:10:30 -0500 Subject: [PATCH] Pull wrapping creation to a var (#8137) * Pull a func out to a var * Funcouttovarextension (#8153) * Update test Co-authored-by: Lexman --- command/server/seal/server_seal_awskms.go | 10 ++++++++-- command/server/seal/server_seal_transit.go | 14 ++++++++++---- .../server/seal/server_seal_transit_acc_test.go | 9 ++++----- 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/command/server/seal/server_seal_awskms.go b/command/server/seal/server_seal_awskms.go index 2d5b71ea583e..173fbaadf7e2 100644 --- a/command/server/seal/server_seal_awskms.go +++ b/command/server/seal/server_seal_awskms.go @@ -3,6 +3,7 @@ package seal import ( "github.com/hashicorp/errwrap" "github.com/hashicorp/go-hclog" + wrapping "github.com/hashicorp/go-kms-wrapping" "github.com/hashicorp/go-kms-wrapping/wrappers/awskms" "github.com/hashicorp/vault/command/server" "github.com/hashicorp/vault/sdk/logical" @@ -10,9 +11,14 @@ import ( "github.com/hashicorp/vault/vault/seal" ) -func configureAWSKMSSeal(configSeal *server.Seal, infoKeys *[]string, info *map[string]string, logger hclog.Logger, inseal vault.Seal) (vault.Seal, error) { +var getAWSKMSFunc = func(opts *wrapping.WrapperOptions, config map[string]string) (wrapping.Wrapper, map[string]string, error) { kms := awskms.NewWrapper(nil) - kmsInfo, err := kms.SetConfig(configSeal.Config) + kmsInfo, err := kms.SetConfig(config) + return kms, kmsInfo, err +} + +func configureAWSKMSSeal(configSeal *server.Seal, infoKeys *[]string, info *map[string]string, logger hclog.Logger, inseal vault.Seal) (vault.Seal, error) { + kms, kmsInfo, err := getAWSKMSFunc(nil, configSeal.Config) if err != nil { // If the error is any other than logical.KeyNotFoundError, return the error if !errwrap.ContainsType(err, new(logical.KeyNotFoundError)) { diff --git a/command/server/seal/server_seal_transit.go b/command/server/seal/server_seal_transit.go index 0a9bc1e7a689..e8838d98a3a1 100644 --- a/command/server/seal/server_seal_transit.go +++ b/command/server/seal/server_seal_transit.go @@ -11,11 +11,17 @@ import ( "github.com/hashicorp/vault/vault/seal" ) +var GetTransitKMSFunc = func(opts *wrapping.WrapperOptions, config map[string]string) (wrapping.Wrapper, map[string]string, error) { + transitSeal := transit.NewWrapper(opts) + sealInfo, err := transitSeal.SetConfig(config) + return transitSeal, sealInfo, err +} + func configureTransitSeal(configSeal *server.Seal, infoKeys *[]string, info *map[string]string, logger log.Logger, inseal vault.Seal) (vault.Seal, error) { - transitSeal := transit.NewWrapper(&wrapping.WrapperOptions{ - Logger: logger.ResetNamed("seal-transit"), - }) - sealInfo, err := transitSeal.SetConfig(configSeal.Config) + transitSeal, sealInfo, err := GetTransitKMSFunc( + &wrapping.WrapperOptions{ + Logger: logger.ResetNamed("seal-transit"), + }, configSeal.Config) if err != nil { // If the error is any other than logical.KeyNotFoundError, return the error if !errwrap.ContainsType(err, new(logical.KeyNotFoundError)) { diff --git a/command/server/seal/server_seal_transit_acc_test.go b/command/server/seal/server_seal_transit_acc_test.go index 5b469c04109c..2c13958fa775 100644 --- a/command/server/seal/server_seal_transit_acc_test.go +++ b/command/server/seal/server_seal_transit_acc_test.go @@ -10,9 +10,9 @@ import ( "testing" "time" - "github.com/hashicorp/go-kms-wrapping/wrappers/transit" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/command/server/seal" "github.com/ory/dockertest" ) @@ -29,8 +29,8 @@ func TestTransitWrapper_Lifecycle(t *testing.T) { "mount_path": mountPath, "key_name": keyName, } - s := transit.NewWrapper(nil) - _, err := s.SetConfig(wrapperConfig) + + s, _, err := seal.GetTransitKMSFunc(nil, wrapperConfig) if err != nil { t.Fatalf("error setting wrapper config: %v", err) } @@ -86,8 +86,7 @@ func TestTransitSeal_TokenRenewal(t *testing.T) { "mount_path": mountPath, "key_name": keyName, } - s := transit.NewWrapper(nil) - _, err = s.SetConfig(wrapperConfig) + s, _, err := seal.GetTransitKMSFunc(nil, wrapperConfig) if err != nil { t.Fatalf("error setting wrapper config: %v", err) }