Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

azurerm_cosmosdb_account: support for CMK through managed_hsm_key_id property #26521

Merged
merged 3 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 195 additions & 0 deletions internal/customermanagedkeys/key_vault_or_managed_hsm_key.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
package customermanagedkeys

import (
"fmt"

"github.com/hashicorp/go-azure-sdk/sdk/environments"
"github.com/hashicorp/terraform-provider-azurerm/internal/services/keyvault/parse"
hsmParse "github.com/hashicorp/terraform-provider-azurerm/internal/services/managedhsm/parse"
"github.com/hashicorp/terraform-provider-azurerm/internal/tf/pluginsdk"
)

type VersionType int

const (
VersionTypeAny VersionType = iota
VersionTypeVersioned
VersionTypeVersionless
)

type KeyVaultOrManagedHSMKey struct {
KeyVaultKeyId *parse.NestedItemId
ManagedHSMKeyId *hsmParse.ManagedHSMDataPlaneVersionedKeyId
ManagedHSMKeyVersionlessId *hsmParse.ManagedHSMDataPlaneVersionlessKeyId
}

func (k *KeyVaultOrManagedHSMKey) IsSet() bool {
return k != nil && (k.KeyVaultKeyId != nil || k.ManagedHSMKeyId != nil || k.ManagedHSMKeyVersionlessId != nil)
}

func (k *KeyVaultOrManagedHSMKey) ID() string {
if k == nil {
return ""
}

if k.KeyVaultKeyId != nil {
return k.KeyVaultKeyId.ID()
}

if k.ManagedHSMKeyId != nil {
return k.ManagedHSMKeyId.ID()
}

if k.ManagedHSMKeyVersionlessId != nil {
return k.ManagedHSMKeyVersionlessId.ID()
}

return ""
}

func (k *KeyVaultOrManagedHSMKey) KeyVaultKeyID() string {
if k != nil && k.KeyVaultKeyId != nil {
return k.KeyVaultKeyId.ID()
}
return ""
}

func (k *KeyVaultOrManagedHSMKey) ManagedHSMKeyID() string {
if k != nil && k.ManagedHSMKeyId != nil {
return k.ManagedHSMKeyId.ID()
}

if k != nil && k.ManagedHSMKeyVersionlessId != nil {
return k.ManagedHSMKeyVersionlessId.ID()
}

return ""
}

func (k *KeyVaultOrManagedHSMKey) BaseUri() string {
if k.KeyVaultKeyId != nil {
return k.KeyVaultKeyId.KeyVaultBaseUrl
}

if k.ManagedHSMKeyId != nil {
return k.ManagedHSMKeyId.BaseUri()
}

if k.ManagedHSMKeyVersionlessId != nil {
return k.ManagedHSMKeyVersionlessId.BaseUri()
}

return ""
}

func parseKeyvaultID(keyRaw string, requireVersion VersionType, _ environments.Api) (*parse.NestedItemId, error) {
keyID, err := parse.ParseOptionallyVersionedNestedKeyID(keyRaw)
if err != nil {
return nil, err
}

if requireVersion == VersionTypeVersioned && keyID.Version == "" {
return nil, fmt.Errorf("expected a key vault versioned ID but no version information was found in: %q", keyRaw)
}

if requireVersion == VersionTypeVersionless && keyID.Version != "" {
return nil, fmt.Errorf("expected a key vault versionless ID but version information was found in: %q", keyRaw)
}

return keyID, nil
}

func parseManagedHSMKey(keyRaw string, requireVersion VersionType, hsmEnv environments.Api) (
versioned *hsmParse.ManagedHSMDataPlaneVersionedKeyId, versionless *hsmParse.ManagedHSMDataPlaneVersionlessKeyId, err error) {
// if specified with hasVersion == True, then it has to be parsed as versionedKeyID
var domainSuffix *string
if hsmEnv != nil {
domainSuffix, _ = hsmEnv.DomainSuffix()
}

switch requireVersion {
case VersionTypeAny:
if versioned, err = hsmParse.ManagedHSMDataPlaneVersionedKeyID(keyRaw, domainSuffix); err != nil {
if versionless, err = hsmParse.ManagedHSMDataPlaneVersionlessKeyID(keyRaw, domainSuffix); err != nil {
return nil, nil, fmt.Errorf("parse Managed HSM both versionedID and versionlessID err for %s", keyRaw)
}
}
case VersionTypeVersioned:
versioned, err = hsmParse.ManagedHSMDataPlaneVersionedKeyID(keyRaw, domainSuffix)
case VersionTypeVersionless:
versionless, err = hsmParse.ManagedHSMDataPlaneVersionlessKeyID(keyRaw, domainSuffix)
}

return versioned, versionless, err
}

func ExpandKeyVaultOrManagedHSMKey(d interface{}, requireVersion VersionType, keyVaultEnv, hsmEnv environments.Api) (*KeyVaultOrManagedHSMKey, error) {
return ExpandKeyVaultOrManagedHSMKeyWithCustomFieldKey(d, requireVersion, "key_vault_key_id", "managed_hsm_key_id", keyVaultEnv, hsmEnv)
}

// ExpandKeyVaultOrManagedHSMKeyWithCustomFieldKey
// d: should be one of *pluginsdk.ResourceData or map[string]interface{}
// if return nil, nil, it means no key_vault_key_id or managed_hsm_key_id is specified
func ExpandKeyVaultOrManagedHSMKeyWithCustomFieldKey(d interface{}, requireVersion VersionType, keyVaultFieldName, hsmFieldName string, keyVaultEnv, hsmEnv environments.Api) (*KeyVaultOrManagedHSMKey, error) {
key := &KeyVaultOrManagedHSMKey{}
var err error
var vaultKeyStr, hsmKeyStr string
if rd, ok := d.(*pluginsdk.ResourceData); ok {
if keyRaw, ok := rd.GetOk(keyVaultFieldName); ok {
vaultKeyStr = keyRaw.(string)
magodo marked this conversation as resolved.
Show resolved Hide resolved
} else if keyRaw, ok = rd.GetOk(hsmFieldName); ok {
hsmKeyStr = keyRaw.(string)
magodo marked this conversation as resolved.
Show resolved Hide resolved
}
} else if obj, ok := d.(map[string]interface{}); ok {
if keyRaw, ok := obj[keyVaultFieldName]; ok {
vaultKeyStr, _ = keyRaw.(string)
}
if keyRaw, ok := obj[hsmFieldName]; ok {
hsmKeyStr, _ = keyRaw.(string)
}
} else {
return nil, fmt.Errorf("not supported data type to parse CMK: %T", d)
}

switch {
case vaultKeyStr != "":
if key.KeyVaultKeyId, err = parseKeyvaultID(vaultKeyStr, requireVersion, keyVaultEnv); err != nil {
return nil, err
}
case hsmKeyStr != "":
if key.ManagedHSMKeyId, key.ManagedHSMKeyVersionlessId, err = parseManagedHSMKey(hsmKeyStr, requireVersion, hsmEnv); err != nil {
return nil, err
}
default:
return nil, nil
}
return key, err
}

// FlattenKeyVaultOrManagedHSMID uses `KeyVaultOrManagedHSMKey.SetState()` to save the state, which this function is designed not to do.
func FlattenKeyVaultOrManagedHSMID(id string, hsmEnv environments.Api) (*KeyVaultOrManagedHSMKey, error) {
magodo marked this conversation as resolved.
Show resolved Hide resolved
if id == "" {
return nil, nil
}

key := &KeyVaultOrManagedHSMKey{}
var err error
key.KeyVaultKeyId, err = parse.ParseOptionallyVersionedNestedKeyID(id)
if err == nil {
return key, nil
}

var domainSuffix *string
if hsmEnv != nil {
domainSuffix, _ = hsmEnv.DomainSuffix()
}
if key.ManagedHSMKeyId, err = hsmParse.ManagedHSMDataPlaneVersionedKeyID(id, domainSuffix); err == nil {
return key, nil
}

if key.ManagedHSMKeyVersionlessId, err = hsmParse.ManagedHSMDataPlaneVersionlessKeyID(id, domainSuffix); err == nil {
return key, nil
}

return nil, fmt.Errorf("cannot parse given id to key vault key nor managed hsm key: %s", id)
}
133 changes: 133 additions & 0 deletions internal/customermanagedkeys/key_vault_or_managed_hsm_key_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package customermanagedkeys_test

import (
"reflect"
"testing"

"github.com/hashicorp/go-azure-sdk/sdk/environments"
"github.com/hashicorp/terraform-provider-azurerm/internal/customermanagedkeys"
"github.com/hashicorp/terraform-provider-azurerm/internal/services/keyvault/parse"
hsmParse "github.com/hashicorp/terraform-provider-azurerm/internal/services/managedhsm/parse"
)

func buildData(keyVaultKey, keyVualtValue, hsmKey, hsmValue string) interface{} {
data := map[string]interface{}{}
if keyVaultKey != "" {
data[keyVaultKey] = keyVualtValue
}

if hsmKey != "" {
data[hsmKey] = hsmValue
}

return data
}

func buildKeyVaultData(key, value string) interface{} {
return buildData(key, value, "", "")
}

func buildHSMData(key, value string) interface{} {
return buildData("", "", key, value)
}

func TestExpandKeyVaultOrManagedHSMKeyKey(t *testing.T) {
type args struct {
d interface{}
hasVersion customermanagedkeys.VersionType
keyVaultFieldName string
hsmFieldName string
hsmEnv environments.Api
}
tests := []struct {
name string
args args
want *customermanagedkeys.KeyVaultOrManagedHSMKey
wantErr bool
}{
{
name: "success with key_vault_key_id",
args: args{
d: buildKeyVaultData("key_vault_key_id", "https://test.keyvault.azure.net/keys/test-key-name"),
keyVaultFieldName: "key_vault_key_id",
},
want: &customermanagedkeys.KeyVaultOrManagedHSMKey{
KeyVaultKeyId: &parse.NestedItemId{
KeyVaultBaseUrl: "https://test.keyvault.azure.net/",
NestedItemType: "keys",
Name: "test-key-name",
},
},
},
{
name: "fail with wrong item type: cert",
args: args{
d: buildKeyVaultData("key_vault_key_id", "https://test.keyvault.azure.net/certs/test-key-name"),
keyVaultFieldName: "key_vault_key_id",
},
wantErr: true,
},
{
name: "fail with wrong field name",
args: args{
d: buildKeyVaultData("key_vault_key_url", "https://test.keyvault.azure.net/keys/test-key-name"),
keyVaultFieldName: "key_vault_key_id",
},
want: nil,
wantErr: false,
},
{
name: "fail with no version provided",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't believe we should fail if the version isn't specified as some resources take versionless ids

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also expand these tests to include versioned ids?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added new test case for versionless key vault id.

args: args{
d: buildKeyVaultData("key_vault_key_id", "https://test.keyvault.azure.net/keys/test-key-name3"),
keyVaultFieldName: "key_vault_key_id",
hasVersion: customermanagedkeys.VersionTypeVersioned,
},
want: nil,
wantErr: true,
},
{
name: "success with versionless key vault id",
args: args{
d: buildKeyVaultData("key_vault_key_id", "https://test.keyvault.azure.net/keys/test-key-versionless"),
keyVaultFieldName: "key_vault_key_id",
hasVersion: customermanagedkeys.VersionTypeVersionless,
},
want: &customermanagedkeys.KeyVaultOrManagedHSMKey{
KeyVaultKeyId: &parse.NestedItemId{
KeyVaultBaseUrl: "https://test.keyvault.azure.net/",
NestedItemType: "keys",
Name: "test-key-versionless",
},
},
wantErr: false,
},
{
name: "success with managed_hsm_key_id",
args: args{
d: buildHSMData("managed_hsm_key_id", "https://test.managedhsm.azure.net/keys/test-key-name"),
hsmFieldName: "managed_hsm_key_id",
hasVersion: customermanagedkeys.VersionTypeVersionless,
},
want: &customermanagedkeys.KeyVaultOrManagedHSMKey{
ManagedHSMKeyVersionlessId: &hsmParse.ManagedHSMDataPlaneVersionlessKeyId{
ManagedHSMName: "test",
DomainSuffix: "managedhsm.azure.net",
KeyName: "test-key-name",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t2 *testing.T) {
got, err := customermanagedkeys.ExpandKeyVaultOrManagedHSMKeyWithCustomFieldKey(tt.args.d, tt.args.hasVersion, tt.args.keyVaultFieldName, tt.args.hsmFieldName, nil, tt.args.hsmEnv)
if (err != nil) != tt.wantErr {
t2.Errorf("ExpandKeyVaultOrManagedHSMKeyWithCustomFieldKey() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t2.Errorf("ExpandKeyVaultOrManagedHSMKeyWithCustomFieldKey() = %v, want %v", got, tt.want)
}
})
}
}
Loading
Loading