Skip to content

Commit

Permalink
feat: nversion support for azkv
Browse files Browse the repository at this point in the history
  • Loading branch information
duffney committed Nov 26, 2024
1 parent 9b5d31b commit 935681c
Show file tree
Hide file tree
Showing 3 changed files with 389 additions and 127 deletions.
151 changes: 97 additions & 54 deletions pkg/keymanagementprovider/azurekeyvault/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,18 @@ import (
"golang.org/x/crypto/pkcs12"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/keyvault/azcertificates"
"github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys"
"github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets"
)

const (
ProviderName string = "azurekeyvault"
PKCS12ContentType string = "application/x-pkcs12"
PEMContentType string = "application/x-pem-file"
ProviderName string = "azurekeyvault"
PKCS12ContentType string = "application/x-pkcs12"
PEMContentType string = "application/x-pem-file"
versionHistoryLimitDefault int = 1
)

var logOpt = logger.Option{
Expand Down Expand Up @@ -88,6 +90,8 @@ type akvKMProviderFactory struct{}
type keyKVClient interface {
// GetKey retrieves a key from the keyvault
GetKey(ctx context.Context, keyName string, keyVersion string) (azkeys.GetKeyResponse, error)
// NewListKeyVersionsPager retrieves a pager for listing key versions
NewListKeyVersionsPager(name string, options *azkeys.ListKeyVersionsOptions) *runtime.Pager[azkeys.ListKeyVersionsResponse]
}
type secretKVClient interface {
// GetSecret retrieves a secret from the keyvault
Expand All @@ -96,6 +100,8 @@ type secretKVClient interface {
type certificateKVClient interface {
// GetCertificate retrieves a certificate from the keyvault
GetCertificate(ctx context.Context, certificateName string, certificateVersion string) (azcertificates.GetCertificateResponse, error)
// NewListCertificateVersionsPager creates a new instance of the ListCertificateVersionsPager
NewListCertificateVersionsPager(certificateName string, options *azcertificates.ListCertificateVersionsOptions) *runtime.Pager[azcertificates.ListCertificateVersionsResponse]
}

type keyKVClientImpl struct {
Expand All @@ -113,11 +119,20 @@ func (c *certificateKVClientImpl) GetCertificate(ctx context.Context, certificat
return c.Client.GetCertificate(ctx, certificateName, certificateVersion, nil)
}

// NewListCertificateVersionsPager retrieves a pager for listing certificate versions
func (c *certificateKVClientImpl) NewListCertificateVersionsPager(certificateName string, options *azcertificates.ListCertificateVersionsOptions) *runtime.Pager[azcertificates.ListCertificateVersionsResponse] {
return c.Client.NewListCertificateVersionsPager(certificateName, options)
}

// GetKey retrieves a key from the keyvault
func (c *keyKVClientImpl) GetKey(ctx context.Context, keyName string, keyVersion string) (azkeys.GetKeyResponse, error) {
return c.Client.GetKey(ctx, keyName, keyVersion, nil)
}

func (c *keyKVClientImpl) NewListKeyVersionsPager(name string, options *azkeys.ListKeyVersionsOptions) *runtime.Pager[azkeys.ListKeyVersionsResponse] {
return c.Client.NewListKeyVersionsPager(name, options)
}

// GetSecret retrieves a secret from the keyvault
func (c *secretKVClientImpl) GetSecret(ctx context.Context, secretName string, secretVersion string) (azsecrets.GetSecretResponse, error) {
return c.Client.GetSecret(ctx, secretName, secretVersion, nil)
Expand Down Expand Up @@ -186,39 +201,51 @@ func (s *akvKMProvider) GetCertificates(ctx context.Context) (map[keymanagementp
logger.GetLogger(ctx, logOpt).Debugf("fetching secret from key vault, certName %v, certVersion %v, vaultURI: %v", keyVaultCert.Name, keyVaultCert.Version, s.vaultURI)

startTime := time.Now()
secretResponse, err := s.secretKVClient.GetSecret(ctx, keyVaultCert.Name, keyVaultCert.Version)
if err != nil {
if isSecretDisabledError(err) {
// if secret is disabled, get the version of the certificate for status
certResponse, err := s.certificateKVClient.GetCertificate(ctx, keyVaultCert.Name, keyVaultCert.Version)
if err != nil {
return nil, nil, fmt.Errorf("failed to get certificate objectName:%s, objectVersion:%s, error: %w", keyVaultCert.Name, keyVaultCert.Version, err)
}
certBundle := certResponse.CertificateBundle
keyVaultCert.Version = getObjectVersion(*certBundle.KID)
isEnabled := *certBundle.Attributes.Enabled
lastRefreshed := startTime.Format(time.RFC3339)
certProperty := getStatusProperty(keyVaultCert.Name, keyVaultCert.Version, lastRefreshed, isEnabled)
certsStatus = append(certsStatus, certProperty)
mapKey := keymanagementprovider.KMPMapKey{Name: keyVaultCert.Name, Version: keyVaultCert.Version, Enabled: isEnabled}
keymanagementprovider.DeleteCertificateFromMap(s.resource, mapKey)
continue
// if versionHistoryLimit is not set, set it to default value 1
if keyVaultCert.VersionHistoryLimit == 0 {
keyVaultCert.VersionHistoryLimit = versionHistoryLimitDefault
}

versionHistory := []string{}
certVersionPager := s.certificateKVClient.NewListCertificateVersionsPager(keyVaultCert.Name, nil)
for certVersionPager.More() {
pager, err := certVersionPager.NextPage(ctx)
if err != nil {
return nil, nil, fmt.Errorf("failed to get certificate versions for objectName:%s, error: %w", keyVaultCert.Name, err)
}
for _, cert := range pager.Value {
versionHistory = append(versionHistory, cert.ID.Version())
}
return nil, nil, fmt.Errorf("failed to get secret objectName:%s, objectVersion:%s, error: %w", keyVaultCert.Name, keyVaultCert.Version, err)
}

secretBundle := secretResponse.SecretBundle
isEnabled := *secretBundle.Attributes.Enabled
for _, version := range versionHistory[:keyVaultCert.VersionHistoryLimit] {
secretReponse, err := s.secretKVClient.GetSecret(ctx, keyVaultCert.Name, version)
if err != nil {
if isSecretDisabledError(err) {
isEnabled := false
lastRefreshed := startTime.Format(time.RFC3339)
certProperty := getStatusProperty(keyVaultCert.Name, version, lastRefreshed, isEnabled)
certsStatus = append(certsStatus, certProperty)
mapKey := keymanagementprovider.KMPMapKey{Name: keyVaultCert.Name, Version: version, Enabled: isEnabled}
keymanagementprovider.DeleteCertificateFromMap(s.resource, mapKey)
continue
}
return nil, nil, fmt.Errorf("failed to get secret objectName:%s, objectVersion:%s, error: %w", keyVaultCert.Name, version, err)
}

certResult, certProperty, err := getCertsFromSecretBundle(ctx, secretBundle, keyVaultCert.Name, isEnabled)
if err != nil {
return nil, nil, fmt.Errorf("failed to get certificates from secret bundle:%w", err)
}
secretBundle := secretReponse.SecretBundle
isEnabled := *secretBundle.Attributes.Enabled

metrics.ReportAKVCertificateDuration(ctx, time.Since(startTime).Milliseconds(), keyVaultCert.Name)
certsStatus = append(certsStatus, certProperty...)
certMapKey := keymanagementprovider.KMPMapKey{Name: keyVaultCert.Name, Version: keyVaultCert.Version, Enabled: isEnabled}
certsMap[certMapKey] = certResult
certResult, certProperty, err := getCertsFromSecretBundle(ctx, secretBundle, keyVaultCert.Name, isEnabled)
if err != nil {
return nil, nil, fmt.Errorf("failed to get certificates from secret bundle:%w", err)
}

metrics.ReportAKVCertificateDuration(ctx, time.Since(startTime).Milliseconds(), keyVaultCert.Name)
certsStatus = append(certsStatus, certProperty...)
certMapKey := keymanagementprovider.KMPMapKey{Name: keyVaultCert.Name, Version: version, Enabled: isEnabled}
certsMap[certMapKey] = certResult
}
}
return certsMap, getStatusMap(certsStatus, types.CertificatesStatus), nil
}
Expand All @@ -233,33 +260,49 @@ func (s *akvKMProvider) GetKeys(ctx context.Context) (map[keymanagementprovider.

// fetch the key object from Key Vault
startTime := time.Now()
keyResponse, err := s.keyKVClient.GetKey(ctx, keyVaultKey.Name, keyVaultKey.Version)
if err != nil {
return nil, nil, fmt.Errorf("failed to get key objectName:%s, objectVersion:%s, error: %w", keyVaultKey.Name, keyVaultKey.Version, err)
// if versionHistoryLimit is not set, set it to default value 1
if keyVaultKey.VersionHistoryLimit == 0 {
keyVaultKey.VersionHistoryLimit = versionHistoryLimitDefault
}
keyBundle := keyResponse.KeyBundle
isEnabled := *keyBundle.Attributes.Enabled
// if version is set as "" in the config, use the version from the key bundle
keyVaultKey.Version = getObjectVersion(string(*keyBundle.Key.KID))

if !isEnabled {
startTime := time.Now()
lastRefreshed := startTime.Format(time.RFC3339)
properties := getStatusProperty(keyVaultKey.Name, keyVaultKey.Version, lastRefreshed, isEnabled)
keysStatus = append(keysStatus, properties)
mapKey := keymanagementprovider.KMPMapKey{Name: keyVaultKey.Name, Version: keyVaultKey.Version, Enabled: isEnabled}
keymanagementprovider.DeleteKeyFromMap(s.resource, mapKey)
continue

versionHistory := []string{}
keyVersionPager := s.keyKVClient.NewListKeyVersionsPager(keyVaultKey.Name, nil)
for keyVersionPager.More() {
pager, err := keyVersionPager.NextPage(ctx)
if err != nil {
return nil, nil, fmt.Errorf("failed to get key versions for objectName:%s, error: %w", keyVaultKey.Name, err)
}
for _, key := range pager.Value {
versionHistory = append(versionHistory, key.KID.Version())
}
}

publicKey, err := getKeyFromKeyBundle(keyBundle)
if err != nil {
return nil, nil, fmt.Errorf("failed to get key from key bundle:%w", err)
for _, version := range versionHistory[:keyVaultKey.VersionHistoryLimit] {
keyResponse, err := s.keyKVClient.GetKey(ctx, keyVaultKey.Name, version)
if err != nil {
return nil, nil, fmt.Errorf("failed to get key objectName:%s, objectVersion:%s, error: %w", keyVaultKey.Name, version, err)
}
keyBundle := keyResponse.KeyBundle
isEnabled := *keyBundle.Attributes.Enabled

if !isEnabled {
lastRefresh := time.Now().Format(time.RFC3339)
keyProperties := getStatusProperty(keyVaultKey.Name, version, lastRefresh, isEnabled)
keysStatus = append(keysStatus, keyProperties)
mapKey := keymanagementprovider.KMPMapKey{Name: keyVaultKey.Name, Version: version, Enabled: isEnabled}
keymanagementprovider.DeleteKeyFromMap(s.resource, mapKey)
continue
}

publicKey, err := getKeyFromKeyBundle(keyBundle)
if err != nil {
return nil, nil, fmt.Errorf("failed to get key from key bundle:%w", err)
}
keysMap[keymanagementprovider.KMPMapKey{Name: keyVaultKey.Name, Version: version, Enabled: isEnabled}] = publicKey
metrics.ReportAKVCertificateDuration(ctx, time.Since(startTime).Milliseconds(), keyVaultKey.Name)
keyProperties := getStatusProperty(keyVaultKey.Name, version, time.Now().Format(time.RFC3339), isEnabled)
keysStatus = append(keysStatus, keyProperties)
}
keysMap[keymanagementprovider.KMPMapKey{Name: keyVaultKey.Name, Version: keyVaultKey.Version, Enabled: isEnabled}] = publicKey
metrics.ReportAKVCertificateDuration(ctx, time.Since(startTime).Milliseconds(), keyVaultKey.Name)
properties := getStatusProperty(keyVaultKey.Name, keyVaultKey.Version, time.Now().Format(time.RFC3339), isEnabled)
keysStatus = append(keysStatus, properties)
}

return keysMap, getStatusMap(keysStatus, types.KeysStatus), nil
Expand Down
Loading

0 comments on commit 935681c

Please sign in to comment.