diff --git a/vault/seal/azurekeyvault/azurekeyvault.go b/vault/seal/azurekeyvault/azurekeyvault.go index 59db4acfe784..0025adb325f5 100644 --- a/vault/seal/azurekeyvault/azurekeyvault.go +++ b/vault/seal/azurekeyvault/azurekeyvault.go @@ -124,10 +124,14 @@ func (v *AzureKeyVaultSeal) SetConfig(config map[string]string) (map[string]stri } // Test the client connection using provided key ID - _, err = client.GetKey(context.Background(), v.buildBaseURL(), v.keyName, "") + keyInfo, err := client.GetKey(context.Background(), v.buildBaseURL(), v.keyName, "") if err != nil { return nil, errwrap.Wrapf("error fetching Azure Key Vault seal key information: {{err}}", err) } + if keyInfo.Key == nil { + return nil, errors.New("no key information returned") + } + v.currentKeyID.Store(parseKeyVersion(to.String(keyInfo.Key.Kid))) v.client = client } @@ -185,10 +189,9 @@ func (v *AzureKeyVaultSeal) Encrypt(ctx context.Context, plaintext []byte) (*phy return nil, err } - // Kid gets returned as a full URL, get the last bit which is just - // the version - keyVersionParts := strings.Split(to.String(resp.Kid), "/") - keyVersion := keyVersionParts[len(keyVersionParts)-1] + // Store the current key version + keyVersion := parseKeyVersion(to.String(resp.Kid)) + v.currentKeyID.Store(keyVersion) ret := &physical.EncryptedBlobInfo{ Ciphertext: env.Ciphertext, @@ -265,3 +268,10 @@ func (v *AzureKeyVaultSeal) getKeyVaultClient() (*keyvault.BaseClient, error) { client.Authorizer = authorizer return &client, nil } + +// Kid gets returned as a full URL, get the last bit which is just +// the version +func parseKeyVersion(kid string) string { + keyVersionParts := strings.Split(kid, "/") + return keyVersionParts[len(keyVersionParts)-1] +} diff --git a/vault/seal/azurekeyvault/azurekeyvault_acc_test.go b/vault/seal/azurekeyvault/azurekeyvault_acc_test.go index c0feccb2842a..7e51ce23aacf 100644 --- a/vault/seal/azurekeyvault/azurekeyvault_acc_test.go +++ b/vault/seal/azurekeyvault/azurekeyvault_acc_test.go @@ -40,6 +40,10 @@ func TestAzureKeyVault_Lifecycle(t *testing.T) { } s := NewSeal(logging.NewVaultLogger(log.Trace)) + _, err := s.SetConfig(nil) + if err != nil { + t.Fatalf("err: %s", err.Error()) + } // Test Encrypt and Decrypt calls input := []byte("foo")